diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index 7d19cf65a0c..b9124449dd8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -693,6 +693,9 @@ void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vectorinputs().size(); ++input_idx) { auto input_node = cnode->input(input_idx); MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::GetInputTensorNum(input_node) == 0) { + continue; + } output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); } } else { diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc index 3102908f3c0..0ed7c6ca663 100644 --- a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc +++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc @@ -91,7 +91,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { } } -void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { +void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &func_output) { MS_EXCEPTION_IF_NULL(cnode); size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); for (size_t i = 0; i < output_num; i++) { @@ -102,6 +102,9 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i); for (size_t j = 0; j < used_node_list->size(); j++) { auto used_node = used_node_list->at(j).first; + if (used_node != func_output) { + continue; + } auto used_node_index = static_cast(used_node_list->at(j).second - 1); auto cur_input = AnfAlgo::GetInputNode(utils::cast(used_node), used_node_index); const std::vector origin_shape = @@ -128,10 +131,11 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) { } AnfNodePtrList outputs; kernel::GetFuncGraphOutputNodes(func_graph, &outputs); + auto func_output = func_graph->output(); for (auto node : outputs) { if (node != nullptr && node->isa() && AnfAlgo::IsRealKernel(node)) { auto cnode = node->cast(); - InsertCastForGraphOutput(func_graph, cnode); + InsertCastForGraphOutput(func_graph, cnode, func_output); } } return true;