From beb110a4763bc4e5c915b28ddf7bba78eff50a08 Mon Sep 17 00:00:00 2001 From: liangzelang Date: Fri, 10 Sep 2021 16:20:16 +0800 Subject: [PATCH] fix ctx device name different with real device name bug. --- mindspore/ccsrc/backend/session/ascend_session.cc | 2 +- .../ccsrc/runtime/device/ascend/ascend_device_address.cc | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 61b62c39004..750ae529476 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -459,7 +459,7 @@ void AscendSession::LoadInputData(const std::shared_ptr &kernel_gra tensor->set_sync_status(kNoNeedSync); } if (device_memcpy_nums > 0) { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime(); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); auto compute_stream = runtime_instance->compute_stream(); auto model_stream = runtime_instance->GetModelStream(kernel_graph->graph_id()); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 44cfc68d63c..3811f69d19e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -98,7 +98,10 @@ bool AsyncMemcpy(void *dst, uint64_t dst_size, const void *src, uint64_t src_siz MS_LOG(INFO) << "dst addr is same with src addr, no need memcpy data."; return true; } - auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); MS_EXCEPTION_IF_NULL(runtime_instance); auto ret = runtime_instance->MemcpyAsync(dst, src, src_size, static_cast(RT_MEMCPY_DEVICE_TO_DEVICE)); if (!ret) {