forked from mindspore-Ecosystem/mindspore
add getptr for memreuse
This commit is contained in:
parent
e2df848597
commit
04be6a37f0
|
@ -42,7 +42,6 @@ KernelRuntime::~KernelRuntime() {
|
||||||
#ifdef ENABLE_DUMP_E2E
|
#ifdef ENABLE_DUMP_E2E
|
||||||
dump_conf_ptr_ = nullptr;
|
dump_conf_ptr_ = nullptr;
|
||||||
#endif
|
#endif
|
||||||
reuse_mem_base_ = nullptr;
|
|
||||||
mem_reuse_util_ptr_ = nullptr;
|
mem_reuse_util_ptr_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -476,9 +475,9 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
|
||||||
bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
|
bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
|
||||||
size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
|
size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
|
||||||
MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
|
MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
|
||||||
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
|
|
||||||
reuse_mem_base_ = base_ptr;
|
|
||||||
mem_reuse_util_ptr_ = mem_reuse_util_ptr;
|
mem_reuse_util_ptr_ = mem_reuse_util_ptr;
|
||||||
|
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
|
||||||
|
mem_reuse_util_ptr_->set_mem_base(base_ptr);
|
||||||
auto &kernels = graph->execution_order();
|
auto &kernels = graph->execution_order();
|
||||||
for (auto &kernel : kernels) {
|
for (auto &kernel : kernels) {
|
||||||
AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);
|
AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);
|
||||||
|
@ -488,22 +487,13 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
|
||||||
|
|
||||||
void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
|
void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto key = node.get();
|
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
auto iter = mem_reuse_util_ptr_->kernel_workspace_refs_.find(key);
|
|
||||||
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
||||||
if (iter != mem_reuse_util_ptr_->kernel_workspace_refs_.end()) {
|
auto wk_ptr = mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
|
||||||
if (index >= iter->second.size()) {
|
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
|
||||||
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size()
|
index++;
|
||||||
<< "]";
|
|
||||||
}
|
|
||||||
auto wk_ref = iter->second[index];
|
|
||||||
auto wk_ptr = reuse_mem_base_ + wk_ref->offset_;
|
|
||||||
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -554,18 +544,7 @@ uint8_t *KernelRuntime::CalDeviceMem(const AnfNodePtr &node, size_t size, int fl
|
||||||
} else if (flag == kDynamicMem) {
|
} else if (flag == kDynamicMem) {
|
||||||
ptr = MallocDynamicMem(size, false);
|
ptr = MallocDynamicMem(size, false);
|
||||||
} else if (flag == kReuseDynamicMem) {
|
} else if (flag == kReuseDynamicMem) {
|
||||||
auto key = node.get();
|
ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
|
||||||
auto iter = mem_reuse_util_ptr_->kernel_output_refs_.find(key);
|
|
||||||
if (iter != mem_reuse_util_ptr_->kernel_output_refs_.end()) {
|
|
||||||
// private member form KernelRuntime
|
|
||||||
memreuse::KernelRefCountPtr kernel_ref_count_ptr = mem_reuse_util_ptr_->kernel_output_refs_[key][index];
|
|
||||||
if (kernel_ref_count_ptr == nullptr) {
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
ptr = reuse_mem_base_ + kernel_ref_count_ptr->offset_;
|
|
||||||
} else {
|
|
||||||
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,9 +128,6 @@ class KernelRuntime {
|
||||||
size_t total_static_size_ = 0;
|
size_t total_static_size_ = 0;
|
||||||
size_t total_dynamic_size_ = 0;
|
size_t total_dynamic_size_ = 0;
|
||||||
MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
|
MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
|
||||||
|
|
||||||
private:
|
|
||||||
uint8_t *reuse_mem_base_{nullptr};
|
|
||||||
};
|
};
|
||||||
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
|
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -316,5 +316,35 @@ void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
|
||||||
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
|
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
|
||||||
|
auto key = node.get();
|
||||||
|
auto iter = kernel_output_refs_.find(key);
|
||||||
|
uint8_t *ptr = nullptr;
|
||||||
|
if (iter != kernel_output_refs_.end()) {
|
||||||
|
if (index >= iter->second.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
|
||||||
|
}
|
||||||
|
auto output_ref = iter->second[index];
|
||||||
|
ptr = mem_base_ + output_ref->offset_;
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
|
||||||
|
}
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
|
||||||
|
auto key = node.get();
|
||||||
|
auto iter = kernel_workspace_refs_.find(key);
|
||||||
|
uint8_t *ptr = nullptr;
|
||||||
|
if (iter != kernel_workspace_refs_.end()) {
|
||||||
|
if (index >= iter->second.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]";
|
||||||
|
}
|
||||||
|
auto wk_ref = iter->second[index];
|
||||||
|
ptr = mem_base_ + wk_ref->offset_;
|
||||||
|
}
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
} // namespace memreuse
|
} // namespace memreuse
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -76,6 +76,9 @@ class MemReuseUtil {
|
||||||
void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
|
void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
|
||||||
kernel_def_ptr_list_ = kernel_def_ptr_list;
|
kernel_def_ptr_list_ = kernel_def_ptr_list;
|
||||||
}
|
}
|
||||||
|
void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
|
||||||
|
uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
|
||||||
|
uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int util_index_;
|
int util_index_;
|
||||||
|
@ -88,6 +91,7 @@ class MemReuseUtil {
|
||||||
size_t total_dy_size_ = 0;
|
size_t total_dy_size_ = 0;
|
||||||
size_t total_workspace_size_ = 0;
|
size_t total_workspace_size_ = 0;
|
||||||
size_t total_reuseworkspace_size_ = 0;
|
size_t total_reuseworkspace_size_ = 0;
|
||||||
|
uint8_t *mem_base_{nullptr};
|
||||||
};
|
};
|
||||||
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
|
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
|
||||||
} // namespace memreuse
|
} // namespace memreuse
|
||||||
|
|
Loading…
Reference in New Issue