diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc index 8df1909b552..7ac56078d98 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -18,6 +18,7 @@ #include #include #include "utils/utils.h" +#include "backend/optimizer/common/helper.h" #include "backend/optimizer/ascend/ascend_helper.h" #include "backend/session/anf_runtime_algorithm.h" #include "utils/ms_context.h" @@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const { return VectorRef({V, Xs}); } -bool IsGraphOutput(const AnfNodePtr &node, const std::vector &outputs) { +bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { + auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}); auto iter = std::find(outputs.begin(), outputs.end(), node); - if (iter != outputs.end()) { + if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) { return true; } - return false; } @@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode && !ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_HOOK)) { - if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { + if (IsGraphOutput(node, func_graph)) { return new_node; } } diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 1a43de26a54..8a7df6e9e3e 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) { } } +size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) { + auto out_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(out_list); + return out_list->size(); +} + std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, const AnfNodePtr &node) { auto output_node_list = std::make_shared>>(); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 9d64a88d020..7865050a35e 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, const AnfNodePtr &node); +size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node); + std::shared_ptr>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index);