diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
index 99f5d491aca..878fe4a7f86 100644
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/device/kernel_runtime.cc
@@ -42,7 +42,6 @@ KernelRuntime::~KernelRuntime() {
 #ifdef ENABLE_DUMP_E2E
   dump_conf_ptr_ = nullptr;
 #endif
-  reuse_mem_base_ = nullptr;
   mem_reuse_util_ptr_ = nullptr;
 }
 
@@ -476,9 +475,9 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
   bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
   size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
   MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
-  auto base_ptr = MallocDynamicMem(total_allocated_size, false);
-  reuse_mem_base_ = base_ptr;
   mem_reuse_util_ptr_ = mem_reuse_util_ptr;
+  auto base_ptr = MallocDynamicMem(total_allocated_size, false);
+  mem_reuse_util_ptr_->set_mem_base(base_ptr);
   auto &kernels = graph->execution_order();
   for (auto &kernel : kernels) {
     AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);
@@ -488,22 +487,13 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
 
 void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
-  auto key = node.get();
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   size_t index = 0;
-  auto iter = mem_reuse_util_ptr_->kernel_workspace_refs_.find(key);
   for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
-    if (iter != mem_reuse_util_ptr_->kernel_workspace_refs_.end()) {
-      if (index >= iter->second.size()) {
-        MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size()
-                          << "]";
-      }
-      auto wk_ref = iter->second[index];
-      auto wk_ptr = reuse_mem_base_ + wk_ref->offset_;
-      AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
-      index++;
-    }
+    auto wk_ptr = mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
+    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
+    index++;
   }
 }
 
@@ -554,18 +544,7 @@ uint8_t *KernelRuntime::CalDeviceMem(const AnfNodePtr &node, size_t size, int fl
   } else if (flag == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
   } else if (flag == kReuseDynamicMem) {
-    auto key = node.get();
-    auto iter = mem_reuse_util_ptr_->kernel_output_refs_.find(key);
-    if (iter != mem_reuse_util_ptr_->kernel_output_refs_.end()) {
-      // private member form KernelRuntime
-      memreuse::KernelRefCountPtr kernel_ref_count_ptr = mem_reuse_util_ptr_->kernel_output_refs_[key][index];
-      if (kernel_ref_count_ptr == nullptr) {
-        return ptr;
-      }
-      ptr = reuse_mem_base_ + kernel_ref_count_ptr->offset_;
-    } else {
-      MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
-    }
+    ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   }
   return ptr;
 }
diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h
index afdb45a6989..ac9a56ed4d8 100644
--- a/mindspore/ccsrc/device/kernel_runtime.h
+++ b/mindspore/ccsrc/device/kernel_runtime.h
@@ -128,9 +128,6 @@ class KernelRuntime {
   size_t total_static_size_ = 0;
   size_t total_dynamic_size_ = 0;
   MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
-
- private:
-  uint8_t *reuse_mem_base_{nullptr};
 };
 using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
 }  // namespace device
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
index 0db3c35196c..2113fec6539 100644
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
+++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
@@ -316,5 +316,48 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
 #endif
 }
+
+// Computes the address for output `index` of `node` as mem_base_ plus the offset_ stored in kernel_output_refs_.
+// Reports through MS_LOG(EXCEPTION) when the node has no entry or `index` is out of range.
+// Returns nullptr when the stored reference is null, matching the inline code this function replaces.
+uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
+  auto key = node.get();
+  auto iter = kernel_output_refs_.find(key);
+  uint8_t *ptr = nullptr;
+  if (iter != kernel_output_refs_.end()) {
+    if (index >= iter->second.size()) {
+      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's output size:[" << iter->second.size() << "]";
+    }
+    auto output_ref = iter->second[index];
+    if (output_ref == nullptr) {
+      // Keep the null guard CalDeviceMem had before the move; dereferencing here would crash.
+      return ptr;
+    }
+    ptr = mem_base_ + output_ref->offset_;
+  } else {
+    MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
+  }
+  return ptr;
+}
+
+// Computes the address for workspace `index` of `node` as mem_base_ plus the offset_ stored in
+// kernel_workspace_refs_. Returns nullptr when the node has no entry in kernel_workspace_refs_.
+// Reports through MS_LOG(EXCEPTION) when `index` is out of range or the stored reference is null.
+uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
+  auto key = node.get();
+  auto iter = kernel_workspace_refs_.find(key);
+  uint8_t *ptr = nullptr;
+  if (iter != kernel_workspace_refs_.end()) {
+    if (index >= iter->second.size()) {
+      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
+    }
+    auto wk_ref = iter->second[index];
+    if (wk_ref == nullptr) {
+      MS_LOG(EXCEPTION) << "workspace ref of index:[" << index << "] is null";
+    }
+    ptr = mem_base_ + wk_ref->offset_;
+  }
+  return ptr;
+}
 }  // namespace memreuse
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h
index 6ecd222688c..cae0e4565f8 100644
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h
+++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h
@@ -76,6 +76,9 @@ class MemReuseUtil {
   void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
     kernel_def_ptr_list_ = kernel_def_ptr_list;
   }
+  void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
+  uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
+  uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
 
  private:
   int util_index_;
@@ -88,6 +91,7 @@ class MemReuseUtil {
   size_t total_dy_size_ = 0;
   size_t total_workspace_size_ = 0;
   size_t total_reuseworkspace_size_ = 0;
+  uint8_t *mem_base_{nullptr};
 };
 using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
 }  // namespace memreuse