!83 add get ptr for memreuseutil
Merge pull request !83 from kisnwang/add-getptr-for-memreuseutil
This commit is contained in:
commit
7341d7ec1e
|
@ -42,7 +42,6 @@ KernelRuntime::~KernelRuntime() {
|
|||
#ifdef ENABLE_DUMP_E2E
|
||||
dump_conf_ptr_ = nullptr;
|
||||
#endif
|
||||
reuse_mem_base_ = nullptr;
|
||||
mem_reuse_util_ptr_ = nullptr;
|
||||
}
|
||||
|
||||
|
@ -476,9 +475,9 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
|
|||
bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
|
||||
size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
|
||||
MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
|
||||
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
|
||||
reuse_mem_base_ = base_ptr;
|
||||
mem_reuse_util_ptr_ = mem_reuse_util_ptr;
|
||||
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
|
||||
mem_reuse_util_ptr_->set_mem_base(base_ptr);
|
||||
auto &kernels = graph->execution_order();
|
||||
for (auto &kernel : kernels) {
|
||||
AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);
|
||||
|
@ -488,22 +487,13 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
|
|||
|
||||
void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto key = node.get();
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
size_t index = 0;
|
||||
auto iter = mem_reuse_util_ptr_->kernel_workspace_refs_.find(key);
|
||||
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
||||
if (iter != mem_reuse_util_ptr_->kernel_workspace_refs_.end()) {
|
||||
if (index >= iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size()
|
||||
<< "]";
|
||||
}
|
||||
auto wk_ref = iter->second[index];
|
||||
auto wk_ptr = reuse_mem_base_ + wk_ref->offset_;
|
||||
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
|
||||
index++;
|
||||
}
|
||||
auto wk_ptr = mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
|
||||
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -554,18 +544,7 @@ uint8_t *KernelRuntime::CalDeviceMem(const AnfNodePtr &node, size_t size, int fl
|
|||
} else if (flag == kDynamicMem) {
|
||||
ptr = MallocDynamicMem(size, false);
|
||||
} else if (flag == kReuseDynamicMem) {
|
||||
auto key = node.get();
|
||||
auto iter = mem_reuse_util_ptr_->kernel_output_refs_.find(key);
|
||||
if (iter != mem_reuse_util_ptr_->kernel_output_refs_.end()) {
|
||||
// private member form KernelRuntime
|
||||
memreuse::KernelRefCountPtr kernel_ref_count_ptr = mem_reuse_util_ptr_->kernel_output_refs_[key][index];
|
||||
if (kernel_ref_count_ptr == nullptr) {
|
||||
return ptr;
|
||||
}
|
||||
ptr = reuse_mem_base_ + kernel_ref_count_ptr->offset_;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
|
||||
}
|
||||
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
|
|
@ -128,9 +128,6 @@ class KernelRuntime {
|
|||
size_t total_static_size_ = 0;
|
||||
size_t total_dynamic_size_ = 0;
|
||||
MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
|
||||
|
||||
private:
|
||||
uint8_t *reuse_mem_base_{nullptr};
|
||||
};
|
||||
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
|
||||
} // namespace device
|
||||
|
|
|
@ -316,5 +316,35 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
|
|||
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
|
||||
#endif
|
||||
}
|
||||
|
||||
uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
|
||||
auto key = node.get();
|
||||
auto iter = kernel_output_refs_.find(key);
|
||||
uint8_t *ptr = nullptr;
|
||||
if (iter != kernel_output_refs_.end()) {
|
||||
if (index >= iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
|
||||
}
|
||||
auto output_ref = iter->second[index];
|
||||
ptr = mem_base_ + output_ref->offset_;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
|
||||
auto key = node.get();
|
||||
auto iter = kernel_workspace_refs_.find(key);
|
||||
uint8_t *ptr = nullptr;
|
||||
if (iter != kernel_workspace_refs_.end()) {
|
||||
if (index >= iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
|
||||
}
|
||||
auto wk_ref = iter->second[index];
|
||||
ptr = mem_base_ + wk_ref->offset_;
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
} // namespace memreuse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,9 @@ class MemReuseUtil {
|
|||
void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
|
||||
kernel_def_ptr_list_ = kernel_def_ptr_list;
|
||||
}
|
||||
void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
|
||||
uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
|
||||
uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
|
||||
|
||||
private:
|
||||
int util_index_;
|
||||
|
@ -88,6 +91,7 @@ class MemReuseUtil {
|
|||
size_t total_dy_size_ = 0;
|
||||
size_t total_workspace_size_ = 0;
|
||||
size_t total_reuseworkspace_size_ = 0;
|
||||
uint8_t *mem_base_{nullptr};
|
||||
};
|
||||
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
|
||||
} // namespace memreuse
|
||||
|
|
Loading…
Reference in New Issue