diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index ffa42c91776..00b31d543ed 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -640,53 +640,14 @@ void FuncGraphTransaction::Commit() { } FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) - : manager_(manager), include_func_graph_none_(false) { - manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); - manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); - manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge); - manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge); - manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); -} - -NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { - include_func_graph_none_ = true; - nodes_analysis_[nullptr] = AnfNodeSet(); - - manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode); - manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode); -} - -void NodesCollector::OnAddNode(AnfNodePtr n) { - if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) { - nodes_analysis_[n->func_graph()] = AnfNodeSet(); - } - nodes_analysis_[n->func_graph()].add(n); -} - -void NodesCollector::OnDropNode(AnfNodePtr n) { - (void)nodes_analysis_[n->func_graph()].erase(n); - auto graph = n->func_graph(); - // Remove the node from order list. - if (graph) { - graph->EraseUnusedNodeInOrder(n); - } -} - -void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // change the owner of node except for the src's return node - for (auto &it : nodes_analysis_[src]) { - nodes_analysis_[dst].add(it); - } - (void)nodes_analysis_.erase(src); -} - -void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } + : manager_(manager), include_func_graph_none_(false) {} DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); - manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); } +void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } + void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } template @@ -735,65 +696,6 @@ bool CounterAnfNodeCollector::Mod(const F } } -void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (inp->isa()) { - (void)Mod(node->func_graph(), inp, direction); - } -} - -void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_nodes_map_.erase(src); -} - -void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, - EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(GetValueNode(inp), std::make_shared(std::make_pair(node, index)), - direction); - } -} - -void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - // Ignore the user graph who may own itself. - if (dst != it.first->first->func_graph()) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_nodes_map_.erase(src); -} - -void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inp); - FuncGraphPtr fg1 = node->func_graph(); - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - (void)Mod(fg1, inp, direction); - } -} - -void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_nodes_map_[src]) { - FuncGraphPtr fg2 = it.first->func_graph(); - if (fg2 != dst) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_nodes_map_.erase(src); -} - -static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { - FuncGraphPtr gn = std::make_shared(); - (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); - return gn; -} - bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { @@ -833,87 +735,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr } } -void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inp); - FuncGraphPtr fg1 = node->func_graph(); - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - (void)Mod(fg2, fg1, direction); - } -} - -void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_func_graphs_map_[src]) { - FuncGraphPtr fg = it.first; - if (fg != dst) { - (void)Inc(dst, fg, it.second); - } - } - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg1 = node->func_graph(); - // possible child parent - if (IsValueNode(inp)) { - FuncGraphPtr fg2 = GetValueNode(inp); - if (Mod(fg1, ParentProxy(fg2), direction)) { - manager_->signals()->InvalidateComputer(); - } - } - // from fv - FuncGraphPtr fg2 = inp->func_graph(); - if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) { - // node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent - if (Mod(fg1, fg2, direction)) { - manager_->signals()->InvalidateComputer(); - } - } -} - -void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto &it : count_func_graphs_map_[src]) { - if (it.first != dst) { - (void)Inc(dst, it.first, it.second); - } - } - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(inp)) { - (void)Mod(node->func_graph(), GetValueNode(inp), direction); - } -} - -void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // all graph use in src need to change to dst, so meger the to dst use - for (auto &it : count_func_graphs_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_func_graphs_map_[dst].erase(src); - (void)count_func_graphs_map_.erase(src); -} - -void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { - if (IsValueNode(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { - (void)Mod(node->func_graph(), GetValueNode(inp), direction); - MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph " - << GetValueNode(inp)->ToString() << " which contains J(func_graph), dir: " << direction; - } -} - -void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - // all graph use in src need to change to dst, so meger the to dst use - for (auto &it : count_func_graphs_map_[src]) { - (void)Inc(dst, it.first, it.second); - } - (void)count_func_graphs_map_.erase(src); -} - DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index cc4336056e6..06b2859feac 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -140,44 +140,6 @@ class FuncGraphAnalysis { using FuncGraphToAnfNodeMap = OrderedMap; -// graphs analysis which compute in write, read needn't recompute -class DepCollector : public FuncGraphAnalysis { - public: - explicit DepCollector(const FuncGraphManager *manager); - ~DepCollector() override = default; - - void Reset() { ExtraReset(); } - void OnInvalidateCollector() { Reset(); } - - protected: - // inherit from FuncGraphAnalysis - void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; - // subclass can override; - virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} -}; - -class NodesCollector final : public DepCollector { - public: - explicit NodesCollector(const FuncGraphManager *m); - ~NodesCollector() override = default; - - const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } - size_t size() const override { return nodes_analysis_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } - - void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); } - - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - FuncGraphToAnfNodeMap nodes_analysis_; - - protected: - void ExtraReset() override { nodes_analysis_.clear(); } - void OnAddNode(AnfNodePtr n) override; - void OnDropNode(AnfNodePtr n) override; -}; - struct CNodeIndexHasher { std::size_t operator()(const CNodeIndexPairPtr pair) const { MS_EXCEPTION_IF_NULL(pair); @@ -204,59 +166,21 @@ struct CNodeIndexEqual { } }; -template , class CollectorEqual = std::equal_to> -class CounterAnfNodeCollector : public DepCollector { +// graphs analysis which compute in write, read needn't recompute +class DepCollector : public FuncGraphAnalysis { public: - explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} - ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } + explicit DepCollector(const FuncGraphManager *manager); + ~DepCollector() override = default; - size_t size() const override { return count_nodes_map_.size(); } - void OnAddFuncGraph(FuncGraphPtr fg) final { - count_nodes_map_[fg] = OrderedMap(); - } - void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - - bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); - bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); - - FuncGraphToAnfNodeCounterMap count_nodes_map_; + void Reset() { ExtraReset(); } + void OnInvalidateCollector() { Reset(); } protected: - void ExtraReset() override { count_nodes_map_.clear(); } -}; - -class ValueNodesCollector final : public CounterAnfNodeCollector { - public: - explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~ValueNodesCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// Record the CNode and its input index, who points to the function graph. -class FuncGraphUsersCNodeIndexCollector final - : public CounterAnfNodeCollector { - public: - explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FuncGraphUsersCNodeIndexCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -class FVDirectCollector final : public CounterAnfNodeCollector { - public: - explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} - ~FVDirectCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; + // inherit from FuncGraphAnalysis + void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; + void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; + // subclass can override; + virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} }; class CounterFuncGraphCollector : public DepCollector { @@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector { void ExtraReset() override { count_func_graphs_map_.clear(); } }; -class FuncGraphChildDirect final : public CounterFuncGraphCollector { +template , class CollectorEqual = std::equal_to> +class CounterAnfNodeCollector : public DepCollector { public: - explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} + ~CounterAnfNodeCollector() override = default; + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } - ~FuncGraphChildDirect() override = default; + size_t size() const override { return count_nodes_map_.size(); } + void OnAddFuncGraph(FuncGraphPtr fg) final { + count_nodes_map_[fg] = OrderedMap(); + } + void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } + + bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); + + FuncGraphToAnfNodeCounterMap count_nodes_map_; protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy -// parents: -// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent -// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f -class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - ~FuncGraphParentsDirectCollector() override = default; - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -// graph's all used graphs: key is g, value is g used graph -class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; - ~FuncGraphsUsedCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; -}; - -class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { - public: - explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} - void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; - ~FuncGraphJDirectCollector() override = default; - - protected: - void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; + void ExtraReset() override { count_nodes_map_.clear(); } }; using FuncGraphToFuncGraphSetMap = OrderedMap; diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 8816277c492..25e66036a18 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -132,18 +132,6 @@ class NestingSpecs { CheckAnfNodeCounter(counter_p); return; } - - auto counter_pair = dynamic_pointer_cast>(results); - if (counter_pair != nullptr) { - CheckCNodeIndexPairCounter(counter_pair); - return; - } - - auto nodes = dynamic_pointer_cast(results); - if (nodes != nullptr) { - CheckNodes(nodes); - return; - } } private: @@ -205,33 +193,7 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } - void CheckNodes(std::shared_ptr results) { - std::map> clean_results; - for (auto& iter : results->nodes_analysis()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - if (!node->isa() && !Name(node).empty()) { - v.insert(Name(node)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - // Add CheckNesting function - void CheckAnfNodeCounter(std::shared_ptr> results) { std::map> clean_results; for (auto& iter : results->count_nodes_map()) { @@ -258,32 +220,6 @@ class NestingSpecs { ASSERT_EQ(clean_results, expected_); } - void CheckCNodeIndexPairCounter(std::shared_ptr> results) { - std::map> clean_results; - for (auto& iter : results->count_nodes_map()) { - auto key = iter.first; - auto value = iter.second; - if (key == nullptr) { - continue; - } - std::string k = Name(key); - - std::set v; - for (auto& node : value) { - auto fg = node.first->first; - if (!Name(fg).empty()) { - v.insert(Name(fg)); - } - } - - if (!v.empty()) { - clean_results[k] = v; - } - } - - ASSERT_EQ(clean_results, expected_); - } - void CheckGraphCounter(std::shared_ptr results) { std::map> clean_results; for (auto& iter : results->count_func_graphs_map()) { @@ -471,17 +407,10 @@ std::vector MakeNestedGraph2() { } // Add TestManager::CheckManager function to checkout the result - void TestManager::CheckAnalysisSize(std::shared_ptr mng) { auto size = mng->func_graphs().size(); - ASSERT_EQ(size + 1, mng->nodes().size()); ASSERT_EQ(size, mng->free_variables_total().size()); - ASSERT_EQ(size, mng->valuenodes().size()); - ASSERT_EQ(size, mng->free_variables_direct().size()); - ASSERT_EQ(size, mng->func_graph_cnodes_index().size()); - ASSERT_EQ(size, mng->func_graph_parents_direct().size()); - ASSERT_EQ(size, mng->func_graphs_used().size()); } TEST_F(TestManager, test_scalar_add_manual) { @@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) { ASSERT_EQ(1, mng->roots().size()); CheckAnalysisSize(mng); - auto nodes = mng->nodes(); - ASSERT_EQ(3, nodes[nullptr].size()); - ASSERT_EQ(2, nodes[f].size()); - ASSERT_EQ(1, nodes[g].size()); + ASSERT_EQ(2, f->nodes().size()); + ASSERT_EQ(1, g->nodes().size()); auto users = mng->node_users(); for (auto& iter : users) { ASSERT_EQ(1, iter.second.size()); } - auto graphs_used = mng->func_graphs_used(); - ASSERT_EQ(1, graphs_used[f].size()); - ASSERT_EQ(0, graphs_used[g].size()); + ASSERT_EQ(1, f->func_graph_value_nodes().size()); + ASSERT_EQ(0, g->func_graph_value_nodes().size()); - auto fv_direct = mng->free_variables_direct(); - ASSERT_EQ(0, fv_direct[f].size()); - ASSERT_EQ(1, fv_direct[g].size()); + ASSERT_EQ(0, f->free_variables().size()); + ASSERT_EQ(1, g->free_variables().size()); auto fv_total = mng->free_variables_total(); ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(1, fv_total[g].size()); - auto cnodes = mng->func_graph_cnodes_index(); - ASSERT_EQ(0, cnodes[f].size()); - ASSERT_EQ(1, cnodes[g].size()); + ASSERT_EQ(0, f->func_graph_cnodes_index().size()); + ASSERT_EQ(1, g->func_graph_cnodes_index().size()); } TEST_F(TestManager, test_deep_nested2_manual) { @@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) { ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(1, mng->roots().size()); - ASSERT_EQ(4, mng->nodes().size()); + ASSERT_EQ(4, gfn->nodes().size()); ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(25, mng->node_users().size()); CheckAnalysisSize(mng); @@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) { ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(1, mng->roots().size()); - ASSERT_EQ(4, mng->nodes().size()); ASSERT_EQ(20, mng->all_nodes().size()); CheckAnalysisSize(mng); } @@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) { FuncGraphPtr fg = getPyFun("ir_get_fn"); auto mng = Manage(fg); - const FuncGraphToAnfNodeMap& nodes = mng->nodes(); - ASSERT_TRUE(nodes.find(fg) != nodes.end()); + const auto &fgs = mng->func_graphs(); + ASSERT_TRUE(fgs.contains(fg)); FuncGraphSet s; s.add(fg); mng->MaybeDropFuncGraphs(s); - ASSERT_TRUE(nodes.find(fg) != nodes.end()); + ASSERT_TRUE(fgs.contains(fg)); } TEST_F(TestManager, test_keep_roots) { diff --git a/tests/ut/cpp/optimizer/cconv_test.cc b/tests/ut/cpp/optimizer/cconv_test.cc index 0b47c78cd32..8bd6957e85f 100644 --- a/tests/ut/cpp/optimizer/cconv_test.cc +++ b/tests/ut/cpp/optimizer/cconv_test.cc @@ -26,15 +26,14 @@ namespace mindspore { void CheckNoFreeVariables(FuncGraphPtr root) { auto mng = Manage(root); - for (auto &iter : mng->nodes()) { - auto g = iter.first; - auto nodes = iter.second; + for (auto &iter : mng->func_graphs()) { + auto g = iter; if (g == nullptr) { continue; } - ASSERT_TRUE(g->parent() == nullptr); + auto nodes = g->nodes(); for (auto &node : nodes) { ASSERT_EQ(node->func_graph(), g); auto cnode = node->cast();