diff --git a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc index bd391a23f51..77d6a164c51 100644 --- a/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/ms/ms_graph_impl.cc @@ -116,15 +116,6 @@ Status MsGraphImpl::FinalizeEnv() { MS_LOG_INFO << "Start finalize env"; session::ExecutorManager::Instance().Clear(); device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return FAILED; - } - if (!context::CloseTsd(ms_context)) { - MS_LOG(ERROR) << "CloseTsd failed!"; - return FAILED; - } init_flag_ = false; MS_LOG(INFO) << "End finalize env"; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index b70b7dbb41d..363917d4f6f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -245,10 +245,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - auto ret = rtSetDevice(context_ptr->get_param(MS_CTX_DEVICE_ID)); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; - } + uint32_t device_id = context_ptr->get_param(MS_CTX_DEVICE_ID); if (mem_manager_ != nullptr) { mem_manager_->FreeDeviceMemory(); @@ -256,7 +253,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { (void)DestroySingleOpHccl(); (void)DestroyHccl(); - (void)ResetDevice(); + (void)ResetDevice(device_id); (void)ProfilingManager::GetInstance().StopProfiling(); MS_LOG(INFO) << "Ascend finalize end"; } @@ -729,7 +726,7 @@ bool AscendKernelRuntime::InitDevice() { return true; } -bool AscendKernelRuntime::ResetDevice() { +bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { InnerSetContext(); if (stream_ != nullptr) { auto ret = rtStreamDestroy(stream_); @@ -747,6 +744,10 @@ bool AscendKernelRuntime::ResetDevice() { rt_context_ = nullptr; } + auto ret = rtDeviceReset(device_id); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]"; + } // set to nullptr as its not created, only bounded to existing context rt_context_hccl_ = nullptr; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index d77744a6143..68d3d990064 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -67,7 +67,7 @@ class AscendKernelRuntime : public KernelRuntime { private: bool InitDevice(); - bool ResetDevice(); + bool ResetDevice(uint32_t device_id); bool HcclInit(); bool NeedDestroyHccl(); bool DestroyHccl(); diff --git a/tests/ut/cpp/stub/runtime/runtime_stub.cc b/tests/ut/cpp/stub/runtime/runtime_stub.cc index b6528dc0657..19c6e1a8bcd 100644 --- a/tests/ut/cpp/stub/runtime/runtime_stub.cc +++ b/tests/ut/cpp/stub/runtime/runtime_stub.cc @@ -37,6 +37,8 @@ rtError_t rtGetDeviceCount(int32_t *count) { return RT_ERROR_NONE; } rtError_t rtSetDevice(int32_t device) { return RT_ERROR_NONE; } +rtError_t rtDeviceReset(int32_t device) { return RT_ERROR_NONE; } + rtError_t rtCtxCreate(rtContext_t *ctx, uint32_t flags, int32_t device) { return RT_ERROR_NONE; } rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; }