diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc index 101c52d45cb..1fe3691e075 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/eliminate_redundant_output.cc @@ -70,8 +70,7 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr bool changed = false; for (const auto &user : users) { if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { - // Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState. - continue; + MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString(); } auto &getitem = user.first; auto idx = GetIndex(getitem); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index 9984ce4e074..2b576b1b0b5 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -32,6 +32,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/graph_kernel/update_state_formatter.h" namespace mindspore { namespace opt { @@ -481,6 +482,7 @@ void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) { } bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) { + (void)std::make_shared()->Run(func_graph); auto mng = func_graph->manager(); MS_EXCEPTION_IF_NULL(mng); Init(func_graph); @@ -493,6 +495,7 @@ bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) { mng->KeepRoots({func_graph}); } Clean(); + (void)std::make_shared()->Run(func_graph); return changed; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc index 6cf1a56e689..95564759272 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc @@ -36,6 +36,7 @@ #include "frontend/operator/ops.h" #include "ir/func_graph_cloner.h" #include "vm/segment_runner.h" +#include "backend/optimizer/graph_kernel/update_state_formatter.h" namespace mindspore { namespace opt { @@ -710,6 +711,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector bool ParallelOpFusion::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); + (void)std::make_shared()->Run(graph); auto kernel_graph = graph->cast>(); MS_EXCEPTION_IF_NULL(kernel_graph); @@ -724,7 +726,9 @@ bool ParallelOpFusion::Run(const FuncGraphPtr &graph) { auto parallel_infos = SearchFusableParallelCNodes(groups); // Create core-fuse subgraph and change origin graph. - return CreateParallelOpSubGraphs(parallel_infos, kernel_graph); + bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph); + (void)std::make_shared()->Run(graph); + return changed; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc index 5078abc23ff..4fe79033ac2 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc @@ -50,7 +50,8 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { return result; } -AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) { +AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, + const FuncGraphPtr &func_graph) { AnfNodePtrList result; for (auto node : nodes) { if (node->abstract()->isa()) { @@ -65,7 +66,6 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); MS_EXCEPTION_IF_NULL(tuple_getitem); - tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i)); tuple_getitem->set_abstract(node_abstract[i]); tuple_getitem->set_kernel_info(std::make_shared()); result.push_back(tuple_getitem); @@ -76,30 +76,29 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod } return result; } + bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { auto todos = GetUpdateStateList(func_graph); bool changed = false; + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); for (auto node : todos) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() <= kUpdateStateRealInput) continue; auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); - // extend inputs of update if which have multiple outputs - inputs = ExtendInputsOfUpdate(inputs, func_graph); - if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) { - AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)}; + // extend inputs of UpdateState if which have multiple outputs + inputs = ExtendInputsOfUpdateState(inputs, func_graph); + if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) { + AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)}; node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end()); - cnode->set_inputs(node_inputs); + // Create a new UpdateState + auto new_node = func_graph->NewCNode(node_inputs); + new_node->set_abstract(node->abstract()); + mng->Replace(node, new_node); changed = true; } } - - if (changed) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - mng->RemoveRoots(); - mng->KeepRoots({func_graph}); - } return changed; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h index ef3be568e38..a62a1e16b47 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h @@ -39,7 +39,7 @@ class SpreadUpdateState : public Pass { public: SpreadUpdateState() : Pass("spread_update_state") {} ~SpreadUpdateState() override = default; - AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph); + AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph); bool Run(const FuncGraphPtr &func_graph) override; };