From 8584f8966f92c320c5c9b6482162796738c5b296 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Mon, 15 Mar 2021 21:08:45 +0800 Subject: [PATCH] fix control flow memory problem --- .../backend/session/ascend_auto_monad.cc | 97 +++++++++++++------ 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index c11c3442402..e4aa5651eda 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -946,12 +946,13 @@ class ExecuteOrderGenerator { graph_->set_execution_order(std::move(execution_order)); } - std::set GetAllNodes() { - auto &all_graphs = context_.visited_graphs(); + std::set GetAllNodes(std::set *search_list) { + const auto &all_graphs = context_.visited_graphs(); std::set all_nodes; for (auto &graph : all_graphs) { auto out = graph->get_return(); MS_EXCEPTION_IF_NULL(out); + search_list->insert(out->cast()); auto nodes = TopoSort(out); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); @@ -971,26 +972,34 @@ class ExecuteOrderGenerator { return input; } - // Erase redundant parameters and assign nodes. - void EraseParameter() { - // Copy out execution order list. - auto exec_order = graph_->execution_order(); - std::set all_nodes = GetAllNodes(); - - // Remove assigns that target and source are same. - for (auto iter = exec_order.begin(); iter != exec_order.end();) { + void RemoveSameInputsAssigns(std::vector *exec_order) { + for (auto iter = exec_order->begin(); iter != exec_order->end();) { auto &node = *iter; auto &inputs = node->inputs(); if (IsPrimitiveCNode(node, prim::kPrimAssign) && (inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) { - iter = exec_order.erase(iter); + iter = exec_order->erase(iter); } else { ++iter; } } + } + + // Erase redundant parameters and assign nodes. + void EraseParameter() { + // Copy out execution order list. + auto exec_order = graph_->execution_order(); + std::set search_list(exec_order.begin(), exec_order.end()); + + // Remove assigns that target and source are same. + RemoveSameInputsAssigns(&exec_order); + + // Get all nodes and all graphs + std::set all_nodes = GetAllNodes(&search_list); + auto &all_graphs = context_.visited_graphs(); // Count parameter write times by check all assign nodes. - auto param_write_times = CountParameterAssigns(exec_order); + auto param_write_times = CountParameterAssigns(search_list); // Erase redundant assigns. for (auto iter = exec_order.begin(); iter != exec_order.end();) { @@ -1008,6 +1017,14 @@ class ExecuteOrderGenerator { MS_EXCEPTION_IF_NULL(kg); kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source)); + // replace parameter in graph input + for (auto &g : all_graphs) { + auto child_graph_inputs = g->MutableInputs(); + std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source); + MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString() + << " in graph " << g->graph_id() << " inputs"; + } + // replace parameter in node for (auto &iter_node : all_nodes) { for (size_t i = 0; i < iter_node->size(); ++i) { @@ -1018,15 +1035,6 @@ class ExecuteOrderGenerator { } } } - - // replace parameter in graph input - auto &all_graphs = context_.visited_graphs(); - for (auto &g : all_graphs) { - auto child_graph_inputs = g->MutableInputs(); - std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source); - MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString() - << " in graph " << g->graph_id() << " inputs"; - } iter = exec_order.erase(iter); continue; } @@ -1039,7 +1047,26 @@ class ExecuteOrderGenerator { } // Count parameter write times by check all assign nodes. - std::map CountParameterAssigns(const std::vector &all_nodes) { + std::map CountParameterAssigns(const std::set &search_list) { + auto ref_map = graph_->GetRefMap(); + std::multimap> ref_multimap; + std::set root_inputs(graph_->inputs().begin(), graph_->inputs().end()); + std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), + [](const std::pair, std::pair> &p) + -> std::pair> { + return {p.first.first, {p.first.second, p.second.first, p.second.second}}; + }); + auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr { + if (node->isa() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto first_input = cnode->input(kFirstDataInputIndex); + MS_EXCEPTION_IF_NULL(first_input); + return first_input; + } + return node; + }; + // Find all graph input parameters. std::map param_write_times; const auto &all_graphs = context_.visited_graphs(); @@ -1051,16 +1078,24 @@ class ExecuteOrderGenerator { } } // Search all nodes for parameter write assigns. - for (auto &node : all_nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimAssign)) { - continue; + for (auto &node : search_list) { + std::set refed_parameters; + for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { + refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second))); } - auto &target = node->inputs().at(kAssignTargetIndex); - MS_EXCEPTION_IF_NULL(target); - auto iter = param_write_times.find(target); - if (iter != param_write_times.end()) { - // Found a parameter writer, count it. - ++(iter->second); + for (auto &in : node->inputs()) { + auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; + visit_node = validate_ref_parameter(visit_node); + if (!visit_node->isa() || root_inputs.find(visit_node) != root_inputs.end()) { + continue; + } + if (refed_parameters.find(visit_node) != refed_parameters.end()) { + auto iter = param_write_times.find(visit_node); + if (iter != param_write_times.end()) { + // Found a parameter writer, count it. + ++(iter->second); + } + } } } return param_write_times;