diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index e080c862853..6bb00d5752e 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K auto new_fg = BasicClone(fg); cnode_inputs.push_back(std::make_shared(new_fg)); } + auto origin_inputs = cnode->inputs(); + bool optimize_depend = false; + if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && + origin_inputs[kRealInputIndexInDepend]->isa()) { + optimize_depend = true; + } // if has multiple depends,only select first depend as parameter - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->inputs()[input_idx]; + for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { + auto anf = origin_inputs[input_idx]; MS_EXCEPTION_IF_NULL(anf); // anf has been created before if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { @@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K (*other_graph_cnode)[anf] = new_parameter; } continue; + } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { + cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); + continue; } else if (anf->isa()) { *from_other_graph = true; // the input node is a cnode from other graph diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 9b2ee51b3fb..db275061343 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -28,6 +28,7 @@ #include #include "utils/log_adapter.h" +#include "utils/utils.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" @@ -85,7 +86,6 @@ std::tuple TransformSegmentToAnfGr if (lst.empty()) { MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { if (a->isa() && !IsValueNode(a)) { eqv[a] = a; @@ -95,17 +95,14 @@ std::tuple TransformSegmentToAnfGr eqv[a]->set_abstract(a->abstract()); eqv[a]->set_kernel_info(a->kernel_info_ptr()); } - return eqv[a]; }; - // Merge CNodes into a AnfGraph that represents a linear instruction segment for (auto n : lst) { if (!n->isa()) { MS_LOG(EXCEPTION) << "Inst is not CNode"; } auto &inps = n->cast()->inputs(); - if (inps.empty()) { MS_LOG(EXCEPTION) << "Input is empty"; } @@ -114,21 +111,22 @@ std::tuple TransformSegmentToAnfGr inps[0]->cast()->value()->cast()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode"; } - auto fn = inps[0]; - std::vector args{fn}; - (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); - + if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa() && + eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { + args.emplace_back(inps[kRealInputIndexInDepend]); + args.emplace_back(inps[kRealInputIndexInDepend]); + } else { + (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); + } eqv[n] = fg->NewCNode(args); eqv[n]->set_abstract(n->abstract()); eqv[n]->set_kernel_info(n->kernel_info_ptr()); } - std::vector eqv_keys; (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); - auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); AnfNodePtr fg_output; if (outputs.size() > 1) { diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index c1fba78be8d..3876f6279c5 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *n } } -bool IsGetItemNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - if (!IsValueNode(inputs[0])) { - return true; - } - PrimitivePtr node_prim = GetValueNode(inputs[0]); - return node_prim->name() == prim::kPrimTupleGetItem->name(); - } - return false; -} - -std::vector ReorderGetItemNode(const std::vector &nodes) { +std::vector OptimizeGetItemOrder(const std::vector &nodes) { std::vector result; std::map> insert_positions; std::map node_positions; for (auto &node : nodes) { - if (IsGetItemNode(node)) { + if (node->isa() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); auto &inputs = cnode->inputs(); @@ -241,7 +224,7 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & } } std::reverse(result.begin(), result.end()); - return ReorderGetItemNode(result); + return result; } } // namespace @@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { return false; } -VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { +VectorRef CompileGraph::SplitNodesWithTarget(const std::vector &input_nodes, const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); + auto nodes = OptimizeGetItemOrder(input_nodes); VectorRef splits; VectorRef split; - auto nodes = TopoSort(graph->get_return()); - if (ContainMultiTarget(nodes)) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); - nodes = SplitSort(graph, default_target); - } std::string last_target; - MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (IsCut(node)) { @@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { return splits; } +VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto nodes = TopoSort(graph->get_return()); + MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); + + if (ContainMultiTarget(nodes)) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + nodes = SplitSort(graph, default_target); + return SplitNodesWithTarget(nodes, graph); + } + + VectorRef splits; + VectorRef split; + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (IsCut(node)) { + if (split.size() != 0) { + splits.push_back(split); + } + splits.push_back(node); + split.clear(); + } else if (node->isa()) { + split.push_back(node); + } + } + return splits; +} + // Push the value node on the stack. void CompileGraph::Push(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index f2d54198d60..a02478fc1ba 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -78,6 +78,7 @@ class CompileGraph { } private: + VectorRef SplitNodesWithTarget(const std::vector &input_nodes, const FuncGraphPtr &graph); void PushParameters(const FuncGraphPtr &func_graph); bool SplitGraph(const FuncGraphPtr &func_graph); int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = "");