diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index fd240d41cbb..32e2d51c252 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1014,6 +1014,7 @@ void AscendSession::AssignStaticMemory(NotNull graph, // assign static memory for parameters auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->ClearGlobalIdleMem(); runtime_instance->AssignStaticMemoryInput(graph.get().get()); runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); for (auto &child_graph : graph->child_graph_order()) { diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 8b9f267bded..7b0f2621cf8 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -155,6 +155,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std } } +void AscendKernelRuntime::ClearGlobalIdleMem() { mem_manager_->ClearGlobalIdleMem(); } + bool AscendKernelRuntime::NeedDestroyHccl() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index f23da565ff4..8afe6a39ca2 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -49,6 +49,7 @@ class AscendKernelRuntime : public KernelRuntime { void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs, const std::unordered_set &value_nodes, const std::vector &execution_order) override; + void ClearGlobalIdleMem() override; bool SyncStream() override; protected: diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index 93e38278f14..bf2f0316f86 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -77,6 +77,8 @@ void AscendMemoryManager::ResetDynamicMemory() { AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); } +void AscendMemoryManager::ClearGlobalIdleMem() { AscendMemoryPool::GetInstance().ResetIdleMemBuf(); } + void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { auto align_size = GetCommonAlignSize(size); return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h index fc684f3fd85..77812b489cb 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h @@ -28,6 +28,7 @@ class AscendMemoryManager : public MemoryManager { void MallocDeviceMemory() override; void FreeDeviceMemory() override; void ResetDynamicMemory() override; + void ClearGlobalIdleMem() override; void *MallocMemFromMemPool(size_t size) override; protected: diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 636b5c88840..1906e778f37 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -74,6 +74,7 @@ class KernelRuntime { const std::unordered_set &value_nodes, const std::vector &execution_order); virtual bool SyncStream() = 0; + virtual void ClearGlobalIdleMem() {} #ifdef ENABLE_DUMP_E2E DumpConfPtr GetDumpConf(); diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h index cb045f8d274..ad2142d1a22 100644 --- a/mindspore/ccsrc/runtime/device/memory_manager.h +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -39,6 +39,7 @@ class MemoryManager { total_dynamic_size_ = 0; dynamic_mem_offset_ = 0; } + virtual void ClearGlobalIdleMem() {} void MallocReusedDynamicMem(const session::KernelGraph *graph); uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,