From dbb86cb1befadbc60b64d80889f1c0b5cfc0cc0a Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Sun, 24 May 2020 18:16:37 +0800 Subject: [PATCH] Adjust some routines of FG and FGM, about the nodes info. IF. --- mindspore/ccsrc/ir/func_graph.cc | 98 ++++++++++++++--------- mindspore/ccsrc/ir/func_graph.h | 32 ++++---- mindspore/ccsrc/ir/func_graph_cloner.cc | 6 +- mindspore/ccsrc/ir/manager.cc | 78 +++++++++--------- mindspore/ccsrc/parallel/step_parallel.cc | 2 +- tests/ut/cpp/ir/manager_test.cc | 4 +- 6 files changed, 116 insertions(+), 104 deletions(-) diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 4833e3838b7..c5d7639e2ee 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() { const AnfNodeSet &FuncGraph::nodes() { return nodes_; } -void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; } +void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } void FuncGraph::ClearNodes() { nodes_.clear(); } @@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) { const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } -void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; } +void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { + auto &others = source->value_nodes(); + for (auto it = others.begin(); it != others.end(); it++) { + AddValueNode(it->first, it->second); + } +} void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } @@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) { const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } -void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) { - auto it = others.begin(); - for (; it != others.end(); it++) { +void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { + auto &others = source->free_variables(); + for (auto it = others.begin(); it != others.end(); it++) { if (it->first->func_graph().get() != this) { (void)AddFreeVariable(it->first, it->second); } @@ -313,31 +318,37 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; } +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } -void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; } +void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { + auto &others = source->func_graphs_used(); + for (auto it = others.begin(); it != others.end(); it++) { + (void)AddFuncGraphUsed(it->first, it->second); + } + func_graphs_used_.erase(source); +} -void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); } +void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } -bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) { - if (func_graph_value_nodes_.count(node) == 0) { - func_graph_value_nodes_[node] = count; +bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { + if (func_graphs_used_.count(fg) == 0) { + func_graphs_used_[fg] = count; return true; } else { - func_graph_value_nodes_[node] += count; + func_graphs_used_[fg] += count; return false; } } -bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) { - if (func_graph_value_nodes_.count(node) != 0) { - if (func_graph_value_nodes_[node] == 1) { - (void)func_graph_value_nodes_.erase(node); +bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { + if (func_graphs_used_.count(fg) != 0) { + if (func_graphs_used_[fg] == 1) { + (void)func_graphs_used_.erase(fg); return true; } else { - func_graph_value_nodes_[node]--; - if (func_graph_value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node + func_graphs_used_[fg]--; + if (func_graphs_used_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); } } @@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } -void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) { - auto it = others.begin(); - for (; it != others.end(); it++) { +void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { + auto &others = source->func_graph_cnodes_index(); + for (auto it = others.begin(); it != others.end(); it++) { // Ignore the user graph who may own itself. - if (it->first->first->func_graph().get() != this) { + auto fg = it->first->first->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + if (fg.get() != this) { AddFuncGraphCNodeIndex(it->first, it->second); } } @@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { } } -const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; } +const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } -void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; } - -void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); } - -void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) { - if (j_func_graph_value_nodes_.count(node) == 0) { - j_func_graph_value_nodes_[node] = count; - } else { - j_func_graph_value_nodes_[node] += count; +void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { + auto &others = source->j_func_graphs(); + for (auto it = others.begin(); it != others.end(); it++) { + AddJFuncGraph(it->first, it->second); } } -void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) { - if (j_func_graph_value_nodes_.count(node) != 0) { - if (j_func_graph_value_nodes_[node] == 1) { - (void)j_func_graph_value_nodes_.erase(node); +void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } + +void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { + if (j_func_graphs_.count(fg) == 0) { + j_func_graphs_[fg] = count; + } else { + j_func_graphs_[fg] += count; + } +} + +void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { + if (j_func_graphs_.count(fg) != 0) { + if (j_func_graphs_[fg] == 1) { + (void)j_func_graphs_.erase(fg); } else { - j_func_graph_value_nodes_[node]--; - if (j_func_graph_value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node + j_func_graphs_[fg]--; + if (j_func_graphs_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); } } diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index f4c9d7079f3..8406f3b1ff5 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase { // get all nodes belonging to this func graph const AnfNodeSet &nodes(); - void CopyNodes(const AnfNodeSet &other_nodes); + void CopyNodes(const FuncGraphPtr &source); void ClearNodes(); void AddNode(AnfNodePtr node); void DropNode(AnfNodePtr node); // get all value_nodes belonging to this func graph const AnfNodeCounterMap &value_nodes(); - void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes); + void CopyValueNodes(const FuncGraphPtr &source); void ClearValueNodes(); void AddValueNode(AnfNodePtr node, int count = 1); void DropValueNode(AnfNodePtr node); // get all free vars directly used in this func graph const AnfNodeCounterMap &free_variables(); - void CopyFreeVariables(const AnfNodeCounterMap &others); + void CopyFreeVariables(const FuncGraphPtr &source); void ClearFreeVariables(); bool AddFreeVariable(AnfNodePtr node, int count = 1); bool DropFreeVariable(AnfNodePtr node); @@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase { std::vector free_variables_func_graphs(); // get all value nodes of func graph directly used by this func graph - const AnfNodeCounterMap &func_graph_value_nodes(); - void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others); - void ClearFuncGraphValueNodes(); - bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1); - bool DropFuncGraphValueNode(AnfNodePtr node); + const FuncGraphCounterMap &func_graphs_used(); + void CopyFuncGraphsUsed(const FuncGraphPtr &source); + void ClearFuncGraphsUsed(); + bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); + bool DropFuncGraphUsed(FuncGraphPtr fg); // get all value nodes of J func graph directly used by this func graph - const AnfNodeCounterMap &j_func_graph_value_nodes(); - void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others); - void ClearJFuncGraphValueNodes(); - void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1); - void DropJFuncGraphValueNode(AnfNodePtr node); + const FuncGraphCounterMap &j_func_graphs(); + void CopyJFuncGraphs(const FuncGraphPtr &source); + void ClearJFuncGraphs(); + void AddJFuncGraph(FuncGraphPtr fg, int count = 1); + void DropJFuncGraph(FuncGraphPtr fg); // get all func graphs nested used by this func graph const FuncGraphSet &func_graphs_used_total(); // get all user value nodes of this func graph, by CNode and its input's index const CNodeIndexCounterMap &func_graph_cnodes_index(); - void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes); + void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source); void ClearFuncGraphCNodesIndex(); void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); @@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase { AnfNodeCounterMap value_nodes_; // all func graph value nodes of the function - AnfNodeCounterMap func_graph_value_nodes_; + FuncGraphCounterMap func_graphs_used_; // all free variables of the function AnfNodeCounterMap free_variables_; // all value nodes calling J in the function - AnfNodeCounterMap j_func_graph_value_nodes_; + FuncGraphCounterMap j_func_graphs_; // all user value nodes of this func graph, recording by CNode and its input's index CNodeIndexCounterMap func_graph_cnodes_index_; diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index 99d7c316e99..db52e08348a 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { if (!clone_all_used_graphs_) { return; } - auto &used = func_graph->func_graph_value_nodes(); - for (auto &fg_value_node : used) { - todo_.push_back({GetValueNode(fg_value_node.first), nullptr, {}}); + auto &used = func_graph->func_graphs_used(); + for (auto &fg : used) { + todo_.push_back({fg.first, nullptr, {}}); } } diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index cfaa84a05bd..a21a794fee2 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { return; } AddIntoManaged(func_graph); - MS_EXCEPTION_IF_NULL(signals_); std::vector para = func_graph->parameters(); AcquireNodes(para); std::vector return_vec({func_graph->get_return()}); @@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool std::vector return_vec = {func_graph->get_return()}; todo.update(MaybeDropNodes(return_vec)); } - MS_EXCEPTION_IF_NULL(signals_); for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); all_nodes_.difference_update(fg->parameters()); @@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E } auto &users_node = node_users_[inp]; users_node.add(make_pair(node, index)); - MS_EXCEPTION_IF_NULL(signals_); AddEdge(node, index, inp); } } @@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector &nodes) { FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { AnfNodeSet nodes_ordered(nodes); FuncGraphSetPtr func_graphs_to_check = std::make_shared(); - MS_EXCEPTION_IF_NULL(signals_); - while (!nodes_ordered.empty()) { AnfNodePtr node = nodes_ordered.pop(); MS_EXCEPTION_IF_NULL(node); @@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp if (input->isa()) { fg->AddValueNode(input); if (IsValueNode(input)) { - if (fg->AddFuncGraphValueNode(input)) { - signals_->InvalidateComputer(); - } auto used = GetValueNode(input); used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->AddFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->AddJFuncGraphValueNode(input); + fg->AddJFuncGraph(used); } } } else if (fg != nullptr && fg != input->func_graph()) { @@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in if (input->isa()) { fg->DropValueNode(input); if (IsValueNode(input)) { - if (fg->DropFuncGraphValueNode(input)) { - signals_->InvalidateComputer(); - } auto used = GetValueNode(input); used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->DropFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->DropJFuncGraphValueNode(input); + fg->DropJFuncGraph(used); } } } else if (fg != nullptr && fg != input->func_graph()) { @@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in } inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { - target->CopyNodes(source->nodes()); - target->CopyValueNodes(source->value_nodes()); - target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index()); - target->CopyFreeVariables(source->free_variables()); - target->CopyFuncGraphValueNodes(source->func_graph_value_nodes()); - target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes()); + target->CopyNodes(source); + target->CopyValueNodes(source); + target->CopyFuncGraphCNodesIndex(source); + target->CopyFreeVariables(source); + target->CopyFuncGraphsUsed(source); + target->CopyJFuncGraphs(source); signals_->InvalidateComputer(); source->ClearNodes(); source->ClearValueNodes(); source->ClearFuncGraphCNodesIndex(); source->ClearFreeVariables(); - source->ClearFuncGraphValueNodes(); - source->ClearJFuncGraphValueNodes(); + source->ClearFuncGraphsUsed(); + source->ClearJFuncGraphs(); } FuncGraphTransaction FuncGraphManager::Transact() { @@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f } // Search the fv in fg's child func graph. - auto &fg_value_nodes = fg->func_graph_value_nodes(); - for (auto &fg_value_node : fg_value_nodes) { + auto &fgs = fg->func_graphs_used(); + for (auto &item : fgs) { fg->seen_ = seen_num; - auto gt = GetValueNode(fg_value_node.first); + auto gt = item.first; parents->update(SeekParents(gt, seen_num)); } (void)parents->erase(fg); @@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() { } } - auto &used = fg->func_graph_value_nodes(); + auto &used = fg->func_graphs_used(); for (auto &iter : used) { - auto p = manager->parent(GetValueNode(iter.first)); + auto p = manager->parent(iter.first); if (p == nullptr) { continue; } auto curr = fg; while (curr != p) { - (void)CounterFuncGraphCollector::Mod(curr, GetValueNode(iter.first), iter.second); + (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); } } @@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : gt->func_graph_value_nodes()) { - auto used_fg = GetValueNode(item.first); + for (auto &item : gt->func_graphs_used()) { + auto used_fg = item.first; if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); continue; @@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f while (!todo.empty()) { todo_new.clear(); for (auto > : todo) { - for (auto &item : gt->func_graph_value_nodes()) { - auto used_g = GetValueNode(item.first); + for (auto &item : gt->func_graphs_used()) { + auto used_g = item.first; if (used_g == fg) { return true; } @@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::listpush_back(fg); - auto &items = fg->func_graph_value_nodes(); + auto &items = fg->func_graphs_used(); for (auto iter = items.begin(); iter != items.end(); (void)iter++) { - CheckRecursiveGraphs(GetValueNode(iter->first), trace); + CheckRecursiveGraphs(iter->first, trace); } trace->pop_back(); if (!recursive_map_.count(fg)) { @@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { MS_LOG(DEBUG) << fg->ToString() << " had been checked"; return false; } - auto &j_fg_value_nodes = fg->j_func_graph_value_nodes(); - if (!j_fg_value_nodes.empty()) { + auto &j_fgs = fg->j_func_graphs(); + if (!j_fgs.empty()) { // check g1->J(fg)->g2->g cycle; - auto contains_j = - std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair iter) { - return GetValueNode(iter.first)->seen_ != seen_num; - }); - if (contains_j != j_fg_value_nodes.end()) { + auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair iter) { + return iter.first->seen_ != seen_num; + }); + if (contains_j != j_fgs.end()) { MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; return true; } @@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { fg->seen_ = seen_num; // check if func graphs used contains J(func_graph); - for (auto &item : fg->func_graph_value_nodes()) { - auto used_g = GetValueNode(item.first); + for (auto &item : fg->func_graphs_used()) { + auto used_g = item.first; if (SeekJ(used_g, seen_num)) { MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; return true; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 3b679a473fe..6c3b51347f1 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) { SetForwardFlag(all_nodes); } else { for (auto &func_graph : graph_set) { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size(); + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); auto return_node = func_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 25e66036a18..7b1e4d8554f 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(1, iter.second.size()); } - ASSERT_EQ(1, f->func_graph_value_nodes().size()); - ASSERT_EQ(0, g->func_graph_value_nodes().size()); + ASSERT_EQ(1, f->func_graphs_used().size()); + ASSERT_EQ(0, g->func_graphs_used().size()); ASSERT_EQ(0, f->free_variables().size()); ASSERT_EQ(1, g->free_variables().size());