forked from mindspore-Ecosystem/mindspore
!13352 fix control flow memory problem
From: @Margaret_wangrui Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5e2151d0cd
|
@ -946,12 +946,13 @@ class ExecuteOrderGenerator {
|
|||
graph_->set_execution_order(std::move(execution_order));
|
||||
}
|
||||
|
||||
std::set<CNodePtr> GetAllNodes() {
|
||||
auto &all_graphs = context_.visited_graphs();
|
||||
std::set<CNodePtr> GetAllNodes(std::set<CNodePtr> *search_list) {
|
||||
const auto &all_graphs = context_.visited_graphs();
|
||||
std::set<CNodePtr> all_nodes;
|
||||
for (auto &graph : all_graphs) {
|
||||
auto out = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
search_list->insert(out->cast<CNodePtr>());
|
||||
auto nodes = TopoSort(out);
|
||||
for (auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -971,26 +972,34 @@ class ExecuteOrderGenerator {
|
|||
return input;
|
||||
}
|
||||
|
||||
// Erase redundant parameters and assign nodes.
|
||||
void EraseParameter() {
|
||||
// Copy out execution order list.
|
||||
auto exec_order = graph_->execution_order();
|
||||
std::set<CNodePtr> all_nodes = GetAllNodes();
|
||||
|
||||
// Remove assigns whose target and source are the same.
|
||||
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
|
||||
void RemoveSameInputsAssigns(std::vector<CNodePtr> *exec_order) {
|
||||
for (auto iter = exec_order->begin(); iter != exec_order->end();) {
|
||||
auto &node = *iter;
|
||||
auto &inputs = node->inputs();
|
||||
if (IsPrimitiveCNode(node, prim::kPrimAssign) &&
|
||||
(inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) {
|
||||
iter = exec_order.erase(iter);
|
||||
iter = exec_order->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Erase redundant parameters and assign nodes.
|
||||
void EraseParameter() {
|
||||
// Copy out execution order list.
|
||||
auto exec_order = graph_->execution_order();
|
||||
std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end());
|
||||
|
||||
// Remove assigns whose target and source are the same.
|
||||
RemoveSameInputsAssigns(&exec_order);
|
||||
|
||||
// Get all nodes and all graphs
|
||||
std::set<CNodePtr> all_nodes = GetAllNodes(&search_list);
|
||||
auto &all_graphs = context_.visited_graphs();
|
||||
|
||||
// Count parameter write times by checking all assign nodes.
|
||||
auto param_write_times = CountParameterAssigns(exec_order);
|
||||
auto param_write_times = CountParameterAssigns(search_list);
|
||||
|
||||
// Erase redundant assigns.
|
||||
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
|
||||
|
@ -1008,6 +1017,14 @@ class ExecuteOrderGenerator {
|
|||
MS_EXCEPTION_IF_NULL(kg);
|
||||
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
|
||||
|
||||
// replace parameter in graph input
|
||||
for (auto &g : all_graphs) {
|
||||
auto child_graph_inputs = g->MutableInputs();
|
||||
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
|
||||
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
|
||||
<< " in graph " << g->graph_id() << " inputs";
|
||||
}
|
||||
|
||||
// replace parameter in node
|
||||
for (auto &iter_node : all_nodes) {
|
||||
for (size_t i = 0; i < iter_node->size(); ++i) {
|
||||
|
@ -1018,15 +1035,6 @@ class ExecuteOrderGenerator {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replace parameter in graph input
|
||||
auto &all_graphs = context_.visited_graphs();
|
||||
for (auto &g : all_graphs) {
|
||||
auto child_graph_inputs = g->MutableInputs();
|
||||
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
|
||||
MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
|
||||
<< " in graph " << g->graph_id() << " inputs";
|
||||
}
|
||||
iter = exec_order.erase(iter);
|
||||
continue;
|
||||
}
|
||||
|
@ -1039,7 +1047,26 @@ class ExecuteOrderGenerator {
|
|||
}
|
||||
|
||||
// Count parameter write times by checking all assign nodes.
|
||||
std::map<AnfNodePtr, int> CountParameterAssigns(const std::vector<CNodePtr> &all_nodes) {
|
||||
std::map<AnfNodePtr, int> CountParameterAssigns(const std::set<CNodePtr> &search_list) {
|
||||
auto ref_map = graph_->GetRefMap();
|
||||
std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
|
||||
std::set<AnfNodePtr> root_inputs(graph_->inputs().begin(), graph_->inputs().end());
|
||||
std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
|
||||
[](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
|
||||
-> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
|
||||
return {p.first.first, {p.first.second, p.second.first, p.second.second}};
|
||||
});
|
||||
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
|
||||
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto first_input = cnode->input(kFirstDataInputIndex);
|
||||
MS_EXCEPTION_IF_NULL(first_input);
|
||||
return first_input;
|
||||
}
|
||||
return node;
|
||||
};
|
||||
|
||||
// Find all graph input parameters.
|
||||
std::map<AnfNodePtr, int> param_write_times;
|
||||
const auto &all_graphs = context_.visited_graphs();
|
||||
|
@ -1051,16 +1078,24 @@ class ExecuteOrderGenerator {
|
|||
}
|
||||
}
|
||||
// Search all nodes for parameter write assigns.
|
||||
for (auto &node : all_nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAssign)) {
|
||||
continue;
|
||||
for (auto &node : search_list) {
|
||||
std::set<AnfNodePtr> refed_parameters;
|
||||
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
|
||||
refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
|
||||
}
|
||||
auto &target = node->inputs().at(kAssignTargetIndex);
|
||||
MS_EXCEPTION_IF_NULL(target);
|
||||
auto iter = param_write_times.find(target);
|
||||
if (iter != param_write_times.end()) {
|
||||
// Found a parameter writer, count it.
|
||||
++(iter->second);
|
||||
for (auto &in : node->inputs()) {
|
||||
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
|
||||
visit_node = validate_ref_parameter(visit_node);
|
||||
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (refed_parameters.find(visit_node) != refed_parameters.end()) {
|
||||
auto iter = param_write_times.find(visit_node);
|
||||
if (iter != param_write_times.end()) {
|
||||
// Found a parameter writer, count it.
|
||||
++(iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return param_write_times;
|
||||
|
|
Loading…
Reference in New Issue