add getptr for memreuse

This commit is contained in:
kswang 2020-04-01 21:37:02 +08:00
parent e2df848597
commit 04be6a37f0
4 changed files with 40 additions and 30 deletions

View File

@ -42,7 +42,6 @@ KernelRuntime::~KernelRuntime() {
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
dump_conf_ptr_ = nullptr; dump_conf_ptr_ = nullptr;
#endif #endif
reuse_mem_base_ = nullptr;
mem_reuse_util_ptr_ = nullptr; mem_reuse_util_ptr_ = nullptr;
} }
@ -476,9 +475,9 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get()); bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize(); size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]"; MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
reuse_mem_base_ = base_ptr;
mem_reuse_util_ptr_ = mem_reuse_util_ptr; mem_reuse_util_ptr_ = mem_reuse_util_ptr;
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
mem_reuse_util_ptr_->set_mem_base(base_ptr);
auto &kernels = graph->execution_order(); auto &kernels = graph->execution_order();
for (auto &kernel : kernels) { for (auto &kernel : kernels) {
AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts); AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);
@ -488,22 +487,13 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) { void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto key = node.get();
auto kernel_mod = AnfAlgo::GetKernelMod(node); auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
size_t index = 0; size_t index = 0;
auto iter = mem_reuse_util_ptr_->kernel_workspace_refs_.find(key);
for (auto &size : kernel_mod->GetWorkspaceSizeList()) { for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
if (iter != mem_reuse_util_ptr_->kernel_workspace_refs_.end()) { auto wk_ptr = mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
if (index >= iter->second.size()) { AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() index++;
<< "]";
}
auto wk_ref = iter->second[index];
auto wk_ptr = reuse_mem_base_ + wk_ref->offset_;
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
index++;
}
} }
} }
@ -554,18 +544,7 @@ uint8_t *KernelRuntime::CalDeviceMem(const AnfNodePtr &node, size_t size, int fl
} else if (flag == kDynamicMem) { } else if (flag == kDynamicMem) {
ptr = MallocDynamicMem(size, false); ptr = MallocDynamicMem(size, false);
} else if (flag == kReuseDynamicMem) { } else if (flag == kReuseDynamicMem) {
auto key = node.get(); ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
auto iter = mem_reuse_util_ptr_->kernel_output_refs_.find(key);
if (iter != mem_reuse_util_ptr_->kernel_output_refs_.end()) {
// private member form KernelRuntime
memreuse::KernelRefCountPtr kernel_ref_count_ptr = mem_reuse_util_ptr_->kernel_output_refs_[key][index];
if (kernel_ref_count_ptr == nullptr) {
return ptr;
}
ptr = reuse_mem_base_ + kernel_ref_count_ptr->offset_;
} else {
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
}
} }
return ptr; return ptr;
} }

View File

@ -128,9 +128,6 @@ class KernelRuntime {
size_t total_static_size_ = 0; size_t total_static_size_ = 0;
size_t total_dynamic_size_ = 0; size_t total_dynamic_size_ = 0;
MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
private:
uint8_t *reuse_mem_base_{nullptr};
}; };
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>; using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
} // namespace device } // namespace device

View File

@ -316,5 +316,35 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
#endif #endif
} }
uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
auto key = node.get();
auto iter = kernel_output_refs_.find(key);
uint8_t *ptr = nullptr;
if (iter != kernel_output_refs_.end()) {
if (index >= iter->second.size()) {
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
}
auto output_ref = iter->second[index];
ptr = mem_base_ + output_ref->offset_;
} else {
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
}
return ptr;
}
uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
auto key = node.get();
auto iter = kernel_workspace_refs_.find(key);
uint8_t *ptr = nullptr;
if (iter != kernel_workspace_refs_.end()) {
if (index >= iter->second.size()) {
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
}
auto wk_ref = iter->second[index];
ptr = mem_base_ + wk_ref->offset_;
}
return ptr;
}
} // namespace memreuse } // namespace memreuse
} // namespace mindspore } // namespace mindspore

View File

@ -76,6 +76,9 @@ class MemReuseUtil {
void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) { void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
kernel_def_ptr_list_ = kernel_def_ptr_list; kernel_def_ptr_list_ = kernel_def_ptr_list;
} }
void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
private: private:
int util_index_; int util_index_;
@ -88,6 +91,7 @@ class MemReuseUtil {
size_t total_dy_size_ = 0; size_t total_dy_size_ = 0;
size_t total_workspace_size_ = 0; size_t total_workspace_size_ = 0;
size_t total_reuseworkspace_size_ = 0; size_t total_reuseworkspace_size_ = 0;
uint8_t *mem_base_{nullptr};
}; };
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>; using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
} // namespace memreuse } // namespace memreuse