diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 37b1dab6c7e..a79dd0933d3 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -97,7 +97,12 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; } - const std::vector &all_nodes = TopoSort(func_node); + const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType { + if (node->func_graph() != fg || node->isa()) { + return EXCLUDE; + } + return FOLLOW; + }); for (const auto &node : all_nodes) { AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 2ad81d61f05..00f9ef48fa3 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -156,6 +156,7 @@ class AnfNode : public Base { return os; } size_t seen_{0}; + size_t extra_seen_{0}; template void set_user_data(const std::string &key, const std::shared_ptr &value) { diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index ccdf8ee1d7a..6a9b17b128c 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -36,50 +36,42 @@ namespace mindspore { std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { size_t seen = NewSeenGeneration(); std::deque todo(1024); - std::unordered_map rank; std::vector res; todo.clear(); todo.push_back(root); while (!todo.empty()) { AnfNodePtr node = todo.back(); - if (node == nullptr || node->seen_ == seen) { + if (node == nullptr) { todo.pop_back(); continue; } - if (rank.find(node) != rank.end() && rank[node] != todo.size()) { + if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag + todo.pop_back(); + continue; + } + auto incl = include(node); + if (node->seen_ == seen) { // We use seen_ as checking flag + todo.pop_back(); + if (incl != EXCLUDE) { + res.push_back(node); + } + node->extra_seen_ = seen; + continue; + } + if (node->seen_ == seen && node->extra_seen_ != seen) { MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); } - rank[node] = todo.size(); - bool cont = false; - auto incl = include(node); + node->seen_ = seen; if (incl == FOLLOW) { auto succs = succ(node); - for (const auto i : succs) { - if ((i != nullptr && i->seen_ != seen) - // Handle the case for 2 subgraphs calls each other. - // If the ValueNodeGraph's return is already in the todo list, do not follow it. - && !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && - (i->func_graph()->get_return() == i))) { - todo.push_back(i); - cont = true; - } - } - } else if (incl == NOFOLLOW) { - // do nothing - } else if (incl == EXCLUDE) { - node->seen_ = seen; - todo.pop_back(); - continue; - } else { - MS_LOG(EXCEPTION) << "include(node) must return one of: \"follow\", \"nofollow\", \"exclude\""; + (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen](const AnfNodePtr &next) { + return next != nullptr && next->seen_ != seen && + (next->func_graph() == nullptr || next->func_graph()->get_return() != next); + }); + } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE + MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; } - if (cont) { - continue; - } - node->seen_ = seen; - res.push_back(node); - todo.pop_back(); } return res; }