!6803 disable memory reuse for selected op in e2e dump

Merge pull request !6803 from laiyongqiang/e2edump_reuse
This commit is contained in:
mindspore-ci-bot 2020-09-28 09:51:39 +08:00 committed by Gitee
commit 5f4e800141
6 changed files with 53 additions and 1 deletions

View File

@ -331,7 +331,7 @@ void DumpJsonParser::JudgeDumpEnabled() {
e2e_dump_enabled_ = false;
MS_LOG(WARNING) << "Dump not enabled. device_id:" << device_id << " not support";
}
context->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, !e2e_dump_enabled_);
JsonConfigToString();
}

View File

@ -333,6 +333,18 @@ bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel,
return false;
}
// Decides whether the memory of `node` has to stay out of the reuse pool.
//
// Returns true only when e2e dump is enabled, the configured dump mode is 1,
// and the dump configuration reports that the kernel identified by the node's
// full scoped name needs to be dumped. In every other case it returns false.
//
// NOTE(review): dump mode 1 presumably means "dump only the selected kernels"
// (mode 0 is logged elsewhere as "dump all kernels") -- confirm against the
// dump configuration documentation.
bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  auto &parser = DumpJsonParser::GetInstance();
  // Guard clause: outside of e2e dump with mode 1 nothing is excluded here.
  if (!parser.e2e_dump_enabled() || parser.dump_mode() != 1) {
    return false;
  }
  auto scoped_name = node->fullname_with_scope();
  return parser.NeedDump(scoped_name);
}
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) {
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);

View File

@ -54,6 +54,7 @@ class AscendKernelRuntime : public KernelRuntime {
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool KernelMemNotReuse(const AnfNodePtr &node) override;
private:
bool InitDevice();

View File

@ -434,7 +434,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
if (type == kReuseDynamicMem) {
// reuse communication op's all outputs' memory
type = kReuseDynamicCommMem;
bool not_reuse = KernelMemNotReuse(node);
if (not_reuse) {
type = kDynamicMem;
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
}
}
uint8_t *output_ptr = nullptr;
for (size_t j = 0; j < align_size_list.size(); ++j) {
std::string output_format = AnfAlgo::GetOutputFormat(node, j);
@ -451,6 +457,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
output_ptr += align_size_list[j];
}
}
// Base-class policy hook: reports whether `node` must be excluded from memory
// reuse. The default answer is always false, i.e. no kernel is excluded; the
// parameter is intentionally unused here. Declared virtual so that a
// device-specific runtime can override it.
bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  return false;
}
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
@ -490,6 +497,15 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
if (addr_size.empty()) {
return;
}
if (type == kReuseDynamicMem) {
bool not_reuse = KernelMemNotReuse(node);
if (not_reuse) {
type = kDynamicMem;
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
}
}
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first);
for (const auto &iter : addr_size) {
MS_EXCEPTION_IF_NULL(iter.first);
@ -513,6 +529,15 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
type = kDynamicMem;
}
}
if (type == kReuseDynamicMem) {
bool not_reuse = KernelMemNotReuse(node);
if (not_reuse) {
type = kDynamicMem;
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
}
}
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
@ -627,9 +652,19 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_mem_reuse = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE);
auto mem_type = kDynamicMem;
auto &dump_json_parser = DumpJsonParser::GetInstance();
if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_REUSE, false);
is_enable_mem_reuse = false;
MS_LOG(INFO) << "Disable Memory Reuse when e2e dump is enable and dump mode is set to dump all kernels";
}
if (is_enable_mem_reuse) {
MS_LOG(INFO) << "Memory Reuse is enable...";
mem_manager_->MallocReusedDynamicMem(graph);
mem_type = kReuseDynamicMem;
} else {
MS_LOG(INFO) << "Memory Reuse is disable...";
}
auto &execution_nodes = graph->execution_order();
std::vector<CNodePtr> compute_nodes;

View File

@ -82,6 +82,8 @@ class KernelRuntime {
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool KernelMemNotReuse(const AnfNodePtr &node);
void AssignStaticMemory(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph);
void ReuseAssignDynamicMemory(session::KernelGraph *graph);

View File

@ -42,6 +42,8 @@ void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) {
MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
mem_reuse_util_ptr_ = mem_reuse_util_ptr;
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
MS_LOG(INFO) << "Reuse Memory from [" << reinterpret_cast<void *>(base_ptr) << "] to ["
<< reinterpret_cast<void *>(base_ptr + total_allocated_size) << "]";
mem_reuse_util_ptr_->set_mem_base(base_ptr);
}