From e6a5fc0739d97989e6c369dee5b2ab20dbd36bd7 Mon Sep 17 00:00:00 2001 From: lingyunli63 Date: Thu, 19 Nov 2020 20:22:56 +0800 Subject: [PATCH] consider controldepend edges in checkcircle --- .../graph_kernel/basic_ops_fusion.cc | 25 ++- .../graph_kernel/composite_ops_fusion.cc | 59 +++++-- .../graph_kernel/composite_ops_fusion.h | 10 +- .../graph_kernel/graph_kernel_helper.cc | 145 ++++++++++++++++-- .../graph_kernel/graph_kernel_helper.h | 22 ++- .../backend/session/anf_runtime_algorithm.cc | 41 +++++ .../backend/session/anf_runtime_algorithm.h | 3 + .../ccsrc/backend/session/kernel_graph.cc | 45 +----- 8 files changed, 264 insertions(+), 86 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index 14d4a42af74..0b9fa666bae 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -47,12 +47,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode return is_fusable ? FOLLOW : EXCLUDE; } -std::vector FindFuseCNodes(const CNodePtr &cnode) { +std::vector FindFuseCNodes(const CNodePtr &cnode, + const std::multimap> &dep_pri) { // Search fusable nodes according input direction. auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes, false); + used_nodes = RemoveCircle(used_nodes, dep_pri, false); } TopoSortForNodeList(&used_nodes); return used_nodes; @@ -78,7 +79,8 @@ void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &u } bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, - std::unordered_map *eqv) { + std::unordered_map *eqv, + std::multimap> *depend_prior) { AnfNodeSet outputs_set; for (auto out : *outputs) { outputs_set.insert(out); @@ -112,6 +114,7 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); mng->Replace(control_depend_node, new_control_depend); has_erase_outs = true; + UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend); } } else { it++; @@ -120,7 +123,8 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out return has_erase_outs; } -void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { +void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng, + std::multimap> *depend_prior) { AnfNodePtrList vir_outputs; std::unordered_map eqv; auto fg_outputs = fg->output(); @@ -137,7 +141,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; } - if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) { + if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) { return; } @@ -159,6 +163,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector *fused_ops) { bool changed = false; auto mng = kernel_graph->manager(); + + // depend_prior[depend] = pair(prior, controlDependNode) + std::multimap> depend_prior; + InitDependPrior(todos, &depend_prior); + for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { auto node = (*iter)->cast(); if (node == nullptr) { @@ -172,7 +181,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector -#include #include -#include #include -#include +#include #include +#include +#include +#include +#include #include #include "frontend/operator/ops.h" @@ -97,15 +98,29 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNod } bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, - std::set *cached_unconnected_set, std::vector *circle_nodes) { + std::set *cached_unconnected_set, std::vector *circle_nodes, + const std::multimap> &depend_prior) { if (!check_node->isa() || !fused_op_set.count(check_node)) { return false; } circle_nodes->clear(); + auto InputEdges = [&depend_prior](CNodePtr cnode) { + std::set edges; + auto range = depend_prior.equal_range(cnode); + for (auto iter = range.first; iter != range.second; ++iter) { + edges.insert(iter->second.first); + } + auto inputs = cnode->inputs(); + for (auto input : inputs) { + edges.insert(input); + } + return edges; + }; + std::set cached_done_set; auto cnode = check_node->cast(); - const auto &inputs = cnode->inputs(); + const auto &inputs = InputEdges(cnode); // there is a input not in fused_op_set, but the input depends on the fused_op_set for (auto input : inputs) { if (input->isa() && !fused_op_set.count(input)) { @@ -128,7 +143,7 @@ bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &che if (node->isa()) { auto cnode_ptr = node->cast(); - for (auto it : cnode_ptr->inputs()) { + for (auto it : InputEdges(cnode_ptr)) { if (it->isa()) { todos.push_back(it); } @@ -148,7 +163,9 @@ bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &che return !circle_nodes->empty(); } -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { +std::vector RemoveCircle(const std::vector &fused_op, + const std::multimap> &depend_prior, + bool is_backward) { std::set cached_unconnected_set; std::set fused_op_set(fused_op.begin(), fused_op.end()); auto include = [&fused_op_set](const AnfNodePtr &node) { @@ -161,7 +178,7 @@ std::vector RemoveCircle(const std::vector &fused_op, bo std::vector circle_nodes; for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { circle_nodes.clear(); - bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); + bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior); // delete the circle node and the node which depend on the circle node in fused op if (has_circle) { auto mng = (*iter)->func_graph()->manager(); @@ -294,7 +311,8 @@ void TopoSortForNodeList(std::vector *lst) { lst->assign(res.begin(), res.end()); } -std::vector FindFuseCNodes(const CNodePtr &cnode) { +std::vector FindFuseCNodes(const CNodePtr &cnode, + const std::multimap> &dep_pri) { auto func_graph = cnode->func_graph(); auto mng = func_graph->manager(); // Search fusable nodes according input direction. @@ -307,7 +325,7 @@ std::vector FindFuseCNodes(const CNodePtr &cnode) { used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes); + used_nodes = RemoveCircle(used_nodes, dep_pri); } used_nodes = RemoveWildGetitem(used_nodes); TopoSortForNodeList(&used_nodes); @@ -316,8 +334,18 @@ std::vector FindFuseCNodes(const CNodePtr &cnode) { bool FuseCompositeOps(const std::shared_ptr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + auto todos = TopoSort(kernel_graph->get_return()); + std::reverse(todos.begin(), todos.end()); + + std::multimap> depend_prior; + InitDependPrior(todos, &depend_prior); + bool changed = false; - auto &todos = kernel_graph->execution_order(); for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { auto node = *iter; if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { @@ -333,13 +361,16 @@ bool FuseCompositeOps(const std::shared_ptr &kernel_graph) } } - auto fuse_nodes = FindFuseCNodes(node); + auto fuse_nodes = FindFuseCNodes(node->cast(), depend_prior); if (fuse_nodes.size() <= 1) { continue; } changed = true; - FuseNodesToSubGraph(fuse_nodes, kernel_graph, ""); + AnfNodePtr fused_new_node; + AnfNodePtrList old_outputs; + std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, ""); + ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs); } return changed; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h index ea2fcb2a2aa..e669a37692b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.h @@ -16,11 +16,13 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ +#include +#include +#include #include #include +#include #include -#include -#include #include "backend/optimizer/common/optimizer.h" #include "backend/session/kernel_graph.h" @@ -29,7 +31,9 @@ namespace opt { const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", "LambNextMV", "LambUpdateWithLR"}; -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward = true); +std::vector RemoveCircle(const std::vector &fused_op, + const std::multimap> &depend_prior, + bool is_backward = true); void TopoSortForNodeList(std::vector *lst); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index e926207f2d4..25f6dc002a3 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -14,20 +14,24 @@ * limitations under the License. */ #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" + #include +#include #include #include -#include "pipeline/jit/parse/python_adapter.h" -#include "pipeline/jit/action.h" +#include + #include "backend/kernel_compiler/common_utils.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "vm/segment_runner.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" #include "backend/kernel_compiler/kernel.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "ir/func_graph_cloner.h" #include "ir/func_graph.h" -#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/action.h" +#include "vm/segment_runner.h" #if ENABLE_GPU #include "runtime/device/gpu/kernel_info_setter.h" #endif @@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f } } -void FuseNodesToSubGraph(const std::vector &fuse_nodes, - const std::shared_ptr &kernel_graph, const std::string &postfix) { - if (fuse_nodes.empty()) { - return; - } - +std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, + const std::shared_ptr &kernel_graph, + const std::string &postfix) { auto mng = kernel_graph->manager(); if (mng == nullptr) { mng = Manage(kernel_graph, true); @@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector &fuse_nodes, } fuse_op_name += postfix; fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); + + return std::make_tuple(fuse_new_node, src_outputs); } bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, @@ -737,7 +740,7 @@ std::vector GetFusibleOpList() { prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, - prim::kPrimTranspose, prim::kPrimCast}; + prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv}; #elif ENABLE_GPU std::vector fusible_basic_ops = { prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, @@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { device::gpu::SetKernelInfo(cnode, kernel_type); #endif } + +void InitDependPrior(const std::vector &todos, + std::multimap> *depend_prior) { + for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { + auto cnode = (*iter)->cast(); + if (cnode == nullptr) { + continue; + } + if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { + continue; + } + + auto prior_node = cnode->input(kControlDependPriorIndex); + auto depend_node = cnode->input(kControlDependBehindIndex); + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(depend_node); + std::vector prior_nodes = {prior_node}; + std::vector depend_nodes = {depend_node}; + + int depend_mode = 0; + if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { + depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); + } + + auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector { + std::vector out_nodes; + auto user_set = param->func_graph()->manager()->node_users()[param]; + for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) { + if (iter->first != cnode) { + out_nodes.push_back(iter->first); + } + } + return out_nodes; + }; + + if (prior_node->isa() && depend_mode == 1) { + prior_nodes = GetOutputNodes(prior_node); + } + if (depend_node->isa()) { + depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector{}; + } + + std::vector real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : prior_nodes) { + AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + prior_visited.clear(); + std::vector real_depend_nodes; + std::set depend_visited; + for (const auto &tmp : depend_nodes) { + AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + } + depend_visited.clear(); + + for (auto &prior : real_prior_nodes) { + if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) { + continue; + } + for (auto &depend : real_depend_nodes) { + if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) { + continue; + } + depend_prior->insert({depend, std::make_pair(prior, cnode)}); + } + } + real_prior_nodes.clear(); + real_depend_nodes.clear(); + } +} + +void UpdateControlDependNode(std::multimap> *depend_prior, + const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) { + for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) { + if (iter->second.second == control_depend_node) { + iter = depend_prior->erase(iter); + continue; + } + ++iter; + } + + std::multimap> new_depend_prior; + InitDependPrior(std::vector{new_control_depend}, &new_depend_prior); + for (auto item : new_depend_prior) { + depend_prior->insert(item); + } +} + +void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, + const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { + std::multimap> new_fuse_cnode_dep_pri; + + for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) { + if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) { + MS_LOG(ERROR) << "Need real outputs of makeTuple"; + } + if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) { + continue; + } + for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) { + if (iter->first == outputs[out_idx]) { + new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second}); + iter = depend_prior->erase(iter); + continue; + } + if (iter->second.first == outputs[out_idx]) { + new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)}); + iter = depend_prior->erase(iter); + continue; + } + ++iter; + } + } + + for (auto item : new_fuse_cnode_dep_pri) { + depend_prior->insert(item); + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index f14696301f0..b75f8efc2c4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -15,17 +15,20 @@ */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ -#include -#include -#include + #include +#include +#include +#include #include #include -#include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" #include "backend/session/kernel_graph.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" +#include namespace mindspore { namespace opt { @@ -48,8 +51,9 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP const AnfNodePtrList &outputs); void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); -void FuseNodesToSubGraph(const std::vector &fuse_nodes, - const std::shared_ptr &kernel_graph, const std::string &postfix); +std::tuple FuseNodesToSubGraph(const std::vector &fuse_nodes, + const std::shared_ptr &kernel_graph, + const std::string &postfix); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, std::map *address_node_map); @@ -60,6 +64,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p std::vector GetFusibleOpList(); bool IsBasicFuseOp(const AnfNodePtr &node); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); +void InitDependPrior(const std::vector &todos, + std::multimap> *depend_prior); +void UpdateControlDependNode(std::multimap> *depend_prior, + const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend); +void ReplaceNewFuseCNodeForDependPrior(std::multimap> *depend_prior, + const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index ebe8fc39b21..bbab8937905 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1417,5 +1417,46 @@ std::vector AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A } return device_shape; } + +void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, + std::set *visited) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(result); + MS_EXCEPTION_IF_NULL(visited); + if (visited->find(anf_node) != visited->end()) { + MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; + return; + } + visited->insert(anf_node); + if (AnfAlgo::IsRealKernel(anf_node)) { + result->emplace_back(anf_node); + return; + } + if (!anf_node->isa()) { + return; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); + } + auto input0 = cnode->input(0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + GetAllFatherRealNode(cnode->input(i), result, visited); + } + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); + } else if (IsPrimitive(input0, prim::kPrimDepend)) { + if (cnode->inputs().size() != kDependInputSize) { + MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); + GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); + } +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 3f9ef917e3a..56c039ba2bd 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -232,6 +232,9 @@ class AnfRuntimeAlgorithm { static bool IsNodeDynamicShape(const AnfNodePtr &node); static std::vector GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); + // Find control_depend real input nodes. + static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, + std::set *visited); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 7b416611470..1a7aeacc7c5 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -725,47 +725,6 @@ std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { return output_nodes; } -// Find control_depend real input nodes. -void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(result); - MS_EXCEPTION_IF_NULL(visited); - if (visited->find(anf_node) != visited->end()) { - MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; - return; - } - visited->insert(anf_node); - if (AnfAlgo::IsRealKernel(anf_node)) { - result->emplace_back(anf_node); - return; - } - if (!anf_node->isa()) { - return; - } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().empty()) { - MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); - } - auto input0 = cnode->input(0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - GetAllFatherRealNode(cnode->input(i), result, visited); - } - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); - } else if (IsPrimitive(input0, prim::kPrimDepend)) { - if (cnode->inputs().size() != kDependInputSize) { - MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; - } - GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); - GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); - } -} - // update the depend relations of control depend void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { for (const auto &node : depends) { @@ -800,12 +759,12 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de std::vector real_prior_nodes; std::set prior_visited; for (const auto &tmp : prior_nodes) { - GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); } std::vector real_depend_nodes; std::set depend_visited; for (const auto &tmp : depend_nodes) { - GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); } UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); }