diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 8a58f320f13..40417a33da9 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() { return used; } -const FuncGraphCounterMap &FuncGraph::func_graph_users() { +const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { auto mng = manager_.lock(); + if (mng == nullptr) { + MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() + << " NodeInfo: " << trace::GetDebugInfo(debug_info()); + } MS_EXCEPTION_IF_NULL(mng); - auto &users = mng->func_graph_users(); - return users[shared_from_base()]; -} - -const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &users = mng->func_graph_user_cnodes(); - return users[shared_from_base()]; + auto &cnode = mng->func_graph_cnodes_index(); + return cnode[shared_from_base()]; } FuncGraphPtr FuncGraph::parent() { diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 9c3752cd816..bca57598078 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -37,6 +37,7 @@ namespace mindspore { using BaseRefCounterMap = OrderedMap; using FuncGraphCounterMap = OrderedMap; using AnfNodeCounterMap = OrderedMap; +using CNodeIndexCounterMap = OrderedMap; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; @@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase { // get all func graphs nested used by this func graph const FuncGraphSet &func_graphs_used_total(); - // get all users of this func graph - const FuncGraphCounterMap &func_graph_users(); - - // get all user cnodes of this func graph - const AnfNodeCounterMap &func_graph_user_cnodes(); + // get all user value nodes of this func graph + const CNodeIndexCounterMap &func_graph_cnodes_index(); // Return the parent of this graph. FuncGraphPtr parent(); diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index c086b8d7d18..c8012276f13 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func } target_func_graph->set_return(return_node); - auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; - for (auto &value_node : value_nodes) { - CloneValueNode(value_node.first, target_func_graph); + auto &cnodes = manager_->func_graph_cnodes_index()[func_graph]; + for (auto &cnode : cnodes) { + auto parent = cnode.first->first->cast(); + auto valuenode = parent->input(cnode.first->second); + CloneValueNode(valuenode, target_func_graph); } } @@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto &user : func_graph_user->func_graph_users()) { - LiftParameters(user.first, func_graph_user, lift_params); + for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); } } @@ -395,8 +397,8 @@ void Cloner::Lift() { for (auto &func_graph_params : repl_func_graph_params_) { auto &func_graph = func_graph_params.first; auto ¶ms = func_graph_params.second; - for (auto &user : func_graph->func_graph_users()) { - LiftParameters(user.first, func_graph, params); + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); } } } diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 150e68ef4db..1ed747eefdc 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -78,13 +78,16 @@ void FuncGraphManager::Reset() { node_users_ = NodeUsersMap(); signals_ = std::make_shared(); + // FuncGraph --> AnfNode nodes_ = std::make_shared(this); + + // FuncGraph --> {AnfNode, Count} valuenodes_ = std::make_shared(this); free_variables_direct_ = std::make_shared(this); - func_graph_valuenodes_ = std::make_shared(this); + func_graph_cnodes_index_ = std::make_shared(this); + + // FuncGraph --> {FuncGraph, Count} func_graphs_used_ = std::make_shared(this); - func_graph_users_ = std::make_shared(this); - func_graph_user_cnodes_ = std::make_shared(this); func_graph_child_direct_ = std::make_shared(this); func_graph_parents_direct_ = std::make_shared(this); func_graph_j_direct_ = std::make_shared(this); @@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); continue; } - MS_EXCEPTION_IF_NULL(func_graph_users_); - auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; - if (!users.empty() && !ignore_users) { + MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_); + auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph]; + if (!users_cnode_index.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; } @@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t node->set_scope(scope); } } - for (auto &used : source->func_graphs_used()) { - (void)func_graph_users_->Inc(used.first, target, used.second); - (void)this->func_graph_users()[used.first].erase(source); - } for (auto &child : this->func_graph_child_direct()[source]) { (void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)this->func_graph_parents_direct()[child.first].erase(source); @@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { +template +bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { auto &d = count_nodes_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; @@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { +template +bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { MS_EXCEPTION_IF_NULL(func_graph); auto &d = count_nodes_map_[func_graph]; if (d.count(key) != 0) { @@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP } else { d[key] -= count; if (d[key] < 0) { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + MS_LOG(EXCEPTION) << "Count of key '" << key << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } } @@ -690,17 +693,78 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { +template +bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, + const ValueT &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { return Dec(func_graph, key, -count); } else { - MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() + MS_LOG(EXCEPTION) << "Count of key '" << key << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } } +void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); + if (inp->isa()) { + (void)Mod(node->func_graph(), inp, direction); + } +} + +void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { + for (auto &it : count_nodes_map_[src]) { + (void)Inc(dst, it.first, it.second); + } + (void)count_nodes_map_.erase(src); +} + +void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, + EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); + if (IsValueNode(inp)) { + (void)Mod(GetValueNode(inp), std::make_shared(std::make_pair(node, index)), + direction); + } +} + +void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { + for (auto &it : count_nodes_map_[src]) { + // Ignore the user graph who may own itself. + if (dst != it.first->first->func_graph()) { + (void)Inc(dst, it.first, it.second); + } + } + (void)count_nodes_map_.erase(src); +} + +void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(inp); + FuncGraphPtr fg1 = node->func_graph(); + FuncGraphPtr fg2 = inp->func_graph(); + if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { + (void)Mod(fg1, inp, direction); + } +} + +void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { + for (auto &it : count_nodes_map_[src]) { + FuncGraphPtr fg2 = it.first->func_graph(); + if (fg2 != dst) { + (void)Inc(dst, it.first, it.second); + } + } + (void)count_nodes_map_.erase(src); +} + +static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { + FuncGraphPtr gn = std::make_shared(); + (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); + return gn; +} + bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { @@ -740,60 +804,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr } } -void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (inp->isa()) { - (void)Mod(node->func_graph(), inp, direction); - } -} - -void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - -// if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self -void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) { - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), inp, direction); - } -} - -void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - -void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inp); - FuncGraphPtr fg1 = node->func_graph(); - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - (void)Mod(fg1, inp, direction); - } -} - -void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - FuncGraphPtr fg2 = it.first->func_graph(); - if (fg2 != dst) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_nodes_map_.erase(src); -} - -static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { - FuncGraphPtr gn = std::make_shared(); - (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); - return gn; -} - void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(inp); @@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) (void)count_func_graphs_map_.erase(src); } -void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), node->func_graph(), direction); - } -} - -void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) { - // all graph use in src need to change to dst, so add dst user - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), node, direction); - } -} - -void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { if (IsValueNode(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { (void)Mod(node->func_graph(), GetValueNode(inp), direction); diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index 54c1e8a6923..7f36b532056 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -100,8 +100,12 @@ struct Signals { enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; +using CNodeIndexPair = std::pair; +using CNodeIndexPairPtr = std::shared_ptr; + using FuncGraphToFuncGraphCounterMap = OrderedMap>; -using FuncGraphToAnfNodeCounterMap = OrderedMap>; +template , class CollectorEqual = std::equal_to> +using FuncGraphToAnfNodeCounterMap = OrderedMap>; // analysis base class class FuncGraphAnalysis { @@ -174,6 +178,87 @@ class NodesCollector final : public DepCollector { void OnDropNode(AnfNodePtr n) override; }; +struct CNodeIndexHasher { + std::size_t operator()(const CNodeIndexPairPtr pair) const { + MS_EXCEPTION_IF_NULL(pair); + MS_EXCEPTION_IF_NULL(pair->first); + return hash_combine(pair->first->hash(), std::hash()(pair->second)); + } +}; + +struct CNodeIndexEqual { + bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + if (lhs->first != rhs->first) { + return false; + } + if (lhs->second != rhs->second) { + return false; + } + return true; + } +}; + +template , class CollectorEqual = std::equal_to> +class CounterAnfNodeCollector : public DepCollector { + public: + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} + ~CounterAnfNodeCollector() override = default; + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } + + size_t size() const override { return count_nodes_map_.size(); } + void OnAddFuncGraph(FuncGraphPtr fg) final { + count_nodes_map_[fg] = OrderedMap(); + } + void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } + + bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); + + FuncGraphToAnfNodeCounterMap count_nodes_map_; + + protected: + void ExtraReset() override { count_nodes_map_.clear(); } +}; + +class ValueNodesCollector final : public CounterAnfNodeCollector { + public: + explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} + ~ValueNodesCollector() override = default; + void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; + + protected: + void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; +}; + +// Record the CNode and its input index, who points to the function graph. +class FuncGraphUsersCNodeIndexCollector final + : public CounterAnfNodeCollector { + public: + explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} + ~FuncGraphUsersCNodeIndexCollector() override = default; + void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; + + protected: + void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; +}; + +class FVDirectCollector final : public CounterAnfNodeCollector { + public: + explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} + ~FVDirectCollector() override = default; + void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; + + protected: + void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; +}; + class CounterFuncGraphCollector : public DepCollector { public: explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} @@ -193,56 +278,6 @@ class CounterFuncGraphCollector : public DepCollector { void ExtraReset() override { count_func_graphs_map_.clear(); } }; -class CounterAnfNodeCollector : public DepCollector { - public: - explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } - - size_t size() const override { return count_nodes_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap(); } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - - bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); - - FuncGraphToAnfNodeCounterMap count_nodes_map_; - - protected: - void ExtraReset() override { count_nodes_map_.clear(); } -}; - -class ValueNodesCollector final : public CounterAnfNodeCollector { - public: - explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~ValueNodesCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { - public: - explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FuncGraphValueNodesCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -class FVDirectCollector final : public CounterAnfNodeCollector { - public: - explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FVDirectCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - class FuncGraphChildDirect final : public CounterFuncGraphCollector { public: explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} @@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; }; -// graph's all user graphs: key is g, value is graphs who used g -class FuncGraphUsersCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphUsersCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// graph's all user cnodes: key is g, value is cnodes who used g -class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { - public: - explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphUserNodesCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { public: explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} @@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer { using FVTotalMap = OrderedMap>; -class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { +class FVTotalComputer final : public DepComputer, + public CounterAnfNodeCollector, + public CounterFuncGraphCollector { public: explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} @@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this { FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &free_variables_direct() const { + return free_variables_direct_->count_nodes_map_; + } - FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_cnodes_index() const { + return func_graph_cnodes_index_->count_nodes_map_; + } FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } - - FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { return func_graph_child_direct_->count_func_graphs_map_; } @@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr nodes_; std::shared_ptr valuenodes_; std::shared_ptr free_variables_direct_; - std::shared_ptr func_graph_valuenodes_; + std::shared_ptr func_graph_cnodes_index_; std::shared_ptr func_graphs_used_; - std::shared_ptr func_graph_users_; - std::shared_ptr func_graph_user_cnodes_; std::shared_ptr func_graph_child_direct_; std::shared_ptr func_graph_parents_direct_; std::shared_ptr func_graph_j_direct_; diff --git a/mindspore/ccsrc/optimizer/irpass/inline.h b/mindspore/ccsrc/optimizer/irpass/inline.h index a7b6b975bb8..8ebd0f6eb7e 100644 --- a/mindspore/ccsrc/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/optimizer/irpass/inline.h @@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { } bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { - auto &users = fg->func_graph_users(); + auto &cnodes = fg->func_graph_cnodes_index(); int n_use = - std::accumulate(users.begin(), users.end(), 0, - [](int sum, const std::pair &item) { return sum + item.second; }); + std::accumulate(cnodes.begin(), cnodes.end(), 0, + [](int sum, const std::pair &item) { return sum + item.second; }); return n_use == 1; } diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index b14bf548699..9147f75fb25 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { } void TraverseGraphMap( - const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, + const FuncGraphToAnfNodeCounterMap &cts, const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr); diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index a7a19a7d24c..8816277c492 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -127,12 +127,18 @@ class NestingSpecs { return; } - auto counter_p = dynamic_pointer_cast(results); + auto counter_p = dynamic_pointer_cast>(results); if (counter_p != nullptr) { CheckAnfNodeCounter(counter_p); return; } + auto counter_pair = dynamic_pointer_cast>(results); + if (counter_pair != nullptr) { + CheckCNodeIndexPairCounter(counter_pair); + return; + } + auto nodes = dynamic_pointer_cast(results); if (nodes != nullptr) { CheckNodes(nodes); @@ -226,7 +232,7 @@ class NestingSpecs { // Add CheckNesting function - void CheckAnfNodeCounter(std::shared_ptr results) { + void CheckAnfNodeCounter(std::shared_ptr> results) { std::map> clean_results; for (auto& iter : results->count_nodes_map()) { auto key = iter.first; @@ -252,6 +258,32 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } + void CheckCNodeIndexPairCounter(std::shared_ptr> results) { + std::map> clean_results; + for (auto& iter : results->count_nodes_map()) { + auto key = iter.first; + auto value = iter.second; + if (key == nullptr) { + continue; + } + std::string k = Name(key); + + std::set v; + for (auto& node : value) { + auto fg = node.first->first; + if (!Name(fg).empty()) { + v.insert(Name(fg)); + } + } + + if (!v.empty()) { + clean_results[k] = v; + } + } + + ASSERT_EQ(clean_results, expected_); + } + void CheckGraphCounter(std::shared_ptr results) { std::map> clean_results; for (auto& iter : results->count_func_graphs_map()) { @@ -447,9 +479,8 @@ void TestManager::CheckAnalysisSize(std::shared_ptr mng) { ASSERT_EQ(size, mng->free_variables_total().size()); ASSERT_EQ(size, mng->valuenodes().size()); ASSERT_EQ(size, mng->free_variables_direct().size()); - ASSERT_EQ(size, mng->func_graph_valuenodes().size()); + ASSERT_EQ(size, mng->func_graph_cnodes_index().size()); ASSERT_EQ(size, mng->func_graph_parents_direct().size()); - ASSERT_EQ(size, mng->func_graph_users().size()); ASSERT_EQ(size, mng->func_graphs_used().size()); } @@ -508,10 +539,6 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(1, graphs_used[f].size()); ASSERT_EQ(0, graphs_used[g].size()); - auto graph_users = mng->func_graph_users(); - ASSERT_EQ(0, graph_users[f].size()); - ASSERT_EQ(1, graph_users[g].size()); - auto fv_direct = mng->free_variables_direct(); ASSERT_EQ(0, fv_direct[f].size()); ASSERT_EQ(1, fv_direct[g].size()); @@ -520,9 +547,9 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(1, fv_total[g].size()); - auto graph_valuenodes = mng->func_graph_valuenodes(); - ASSERT_EQ(0, graph_valuenodes[f].size()); - ASSERT_EQ(1, graph_valuenodes[g].size()); + auto cnodes = mng->func_graph_cnodes_index(); + ASSERT_EQ(0, cnodes[f].size()); + ASSERT_EQ(1, cnodes[g].size()); } TEST_F(TestManager, test_deep_nested2_manual) {