diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 16ad0dd4965..8f43ce27cd4 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -817,7 +817,7 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g } // construct graph include one op - auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); + auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true); MS_EXCEPTION_IF_NULL(graph); opt::RunOpAscendBackendIRFusionOptimization(graph); // kernel select diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ca6c351de9f..5b749fc871e 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1569,7 +1569,8 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, const std::vector &input_tensors, - const std::vector &tensors_mask) { + const std::vector &tensors_mask, + bool is_ascend) { auto graph = std::make_shared(); graph->set_graph_id(graph_sum_); graph_sum_++; @@ -1612,7 +1613,11 @@ std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInf graph->set_execution_order(exe_order); graph->UpdateGraphDynamicAttr(); // set output - CreateOutputNode(cnode, graph); + if (is_ascend) { + graph->set_output(cnode); + } else { + CreateOutputNode(cnode, graph); + } graph->SetInputNodes(); auto manager = MakeManager({graph}); if (manager != nullptr) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index d3b986531b8..6f45f1d0a94 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -180,7 +180,7 @@ class SessionBasic : public std::enable_shared_from_this { // create a single run op graph std::shared_ptr ConstructSingleOpGraph(const OpRunInfo &op_run_info, const std::vector &input_tensors, - const std::vector &tensors_mask); + const std::vector &tensors_mask, bool is_ascend = false); // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); std::vector CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);