diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index 5b6e36517c1..cfeef8ce19b 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -217,6 +217,20 @@ std::vector GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c } return inputs; } + +bool HasSubgraph(const std::shared_ptr &func_graph) { + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto sub_graph = GetCNodeFuncGraph(node); + if (sub_graph != nullptr) { + return true; + } + } + return false; +} } // namespace // ---------------implement of DfGraphConvertor------------- @@ -1543,7 +1557,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) { MS_LOG(INFO) << "Set graph input num: " << inputs.size(); (void)df_graph_->SetInputs(inputs); - SetGraphOutputs(); + SetGraphOutputs(true); (void)df_graph_->SetOutputs(graph_outputs_); IdentityOptimization(); @@ -1560,15 +1574,27 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) { return *this; } -void DfGraphConvertor::SetGraphOutputs() { +void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) { if (cur_while_node_ == nullptr) { graph_outputs_.clear(); + std::vector return_nodes; auto ret_node = anf_graph_->get_return(); - auto adpt = FindAdapter(ret_node, training_); - MS_EXCEPTION_IF_NULL(adpt); - auto handles = adpt->getOutputs(Convert(ret_node)); - for (const auto &handle : handles) { - (void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out)); + // replace return node with graph output node. + if (!HasSubgraph(anf_graph_) && is_main_graph) { + auto output_nodes = ret_node->inputs(); + return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end()); + } else { + return_nodes.push_back(ret_node); + } + for (size_t i = 0; i < return_nodes.size(); i++) { + auto output_node = return_nodes[i]; + MS_EXCEPTION_IF_NULL(output_node); + auto adpt = FindAdapter(output_node, training_); + MS_EXCEPTION_IF_NULL(adpt); + auto handles = adpt->getOutputs(Convert(output_node)); + for (const auto &handle : handles) { + (void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out)); + } } } diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index c377b2056a9..84025de1bb4 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -250,7 +250,7 @@ class DfGraphConvertor { void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node, const CNodePtr &from_node_input); void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input); - void SetGraphOutputs(); + void SetGraphOutputs(bool is_main_graph = false); std::vector GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input); // Identity Optimization