!13352 fix control flow memory problem

From: @Margaret_wangrui
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-17 09:18:26 +08:00 committed by Gitee
commit 5e2151d0cd
1 changed files with 66 additions and 31 deletions

View File

@ -946,12 +946,13 @@ class ExecuteOrderGenerator {
graph_->set_execution_order(std::move(execution_order));
}
std::set<CNodePtr> GetAllNodes() {
auto &all_graphs = context_.visited_graphs();
std::set<CNodePtr> GetAllNodes(std::set<CNodePtr> *search_list) {
const auto &all_graphs = context_.visited_graphs();
std::set<CNodePtr> all_nodes;
for (auto &graph : all_graphs) {
auto out = graph->get_return();
MS_EXCEPTION_IF_NULL(out);
search_list->insert(out->cast<CNodePtr>());
auto nodes = TopoSort(out);
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
@ -971,26 +972,34 @@ class ExecuteOrderGenerator {
return input;
}
// Erase redundant parameters and assign nodes.
void EraseParameter() {
// Copy out execution order list.
auto exec_order = graph_->execution_order();
std::set<CNodePtr> all_nodes = GetAllNodes();
// Remove assigns that target and source are same.
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
void RemoveSameInputsAssigns(std::vector<CNodePtr> *exec_order) {
for (auto iter = exec_order->begin(); iter != exec_order->end();) {
auto &node = *iter;
auto &inputs = node->inputs();
if (IsPrimitiveCNode(node, prim::kPrimAssign) &&
(inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) {
iter = exec_order.erase(iter);
iter = exec_order->erase(iter);
} else {
++iter;
}
}
}
// Erase redundant parameters and assign nodes.
void EraseParameter() {
// Copy out execution order list.
auto exec_order = graph_->execution_order();
std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end());
// Remove assigns that target and source are same.
RemoveSameInputsAssigns(&exec_order);
// Get all nodes and all graphs
std::set<CNodePtr> all_nodes = GetAllNodes(&search_list);
auto &all_graphs = context_.visited_graphs();
// Count parameter write times by check all assign nodes.
auto param_write_times = CountParameterAssigns(exec_order);
auto param_write_times = CountParameterAssigns(search_list);
// Erase redundant assigns.
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
@ -1008,6 +1017,14 @@ class ExecuteOrderGenerator {
MS_EXCEPTION_IF_NULL(kg);
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
// replace parameter in graph input
for (auto &g : all_graphs) {
auto child_graph_inputs = g->MutableInputs();
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
<< " in graph " << g->graph_id() << " inputs";
}
// replace parameter in node
for (auto &iter_node : all_nodes) {
for (size_t i = 0; i < iter_node->size(); ++i) {
@ -1018,15 +1035,6 @@ class ExecuteOrderGenerator {
}
}
}
// replace parameter in graph input
auto &all_graphs = context_.visited_graphs();
for (auto &g : all_graphs) {
auto child_graph_inputs = g->MutableInputs();
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
<< " in graph " << g->graph_id() << " inputs";
}
iter = exec_order.erase(iter);
continue;
}
@ -1039,7 +1047,26 @@ class ExecuteOrderGenerator {
}
// Count parameter write times by check all assign nodes.
std::map<AnfNodePtr, int> CountParameterAssigns(const std::vector<CNodePtr> &all_nodes) {
std::map<AnfNodePtr, int> CountParameterAssigns(const std::set<CNodePtr> &search_list) {
auto ref_map = graph_->GetRefMap();
std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
std::set<AnfNodePtr> root_inputs(graph_->inputs().begin(), graph_->inputs().end());
std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
[](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
-> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
return {p.first.first, {p.first.second, p.second.first, p.second.second}};
});
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto first_input = cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(first_input);
return first_input;
}
return node;
};
// Find all graph input parameters.
std::map<AnfNodePtr, int> param_write_times;
const auto &all_graphs = context_.visited_graphs();
@ -1051,16 +1078,24 @@ class ExecuteOrderGenerator {
}
}
// Search all nodes for parameter write assigns.
for (auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimAssign)) {
continue;
for (auto &node : search_list) {
std::set<AnfNodePtr> refed_parameters;
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
}
auto &target = node->inputs().at(kAssignTargetIndex);
MS_EXCEPTION_IF_NULL(target);
auto iter = param_write_times.find(target);
if (iter != param_write_times.end()) {
// Found a parameter writer, count it.
++(iter->second);
for (auto &in : node->inputs()) {
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
visit_node = validate_ref_parameter(visit_node);
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
continue;
}
if (refed_parameters.find(visit_node) != refed_parameters.end()) {
auto iter = param_write_times.find(visit_node);
if (iter != param_write_times.end()) {
// Found a parameter writer, count it.
++(iter->second);
}
}
}
}
return param_write_times;