diff --git a/mindspore/lite/src/control_flow/control_flow_scheduler.cc b/mindspore/lite/src/control_flow/control_flow_scheduler.cc index 926d11a86c9..41a4d27af2f 100644 --- a/mindspore/lite/src/control_flow/control_flow_scheduler.cc +++ b/mindspore/lite/src/control_flow/control_flow_scheduler.cc @@ -25,10 +25,16 @@ #include "src/tensorlist.h" #include "src/common/prim_inner.h" +namespace { +const constexpr int kMinNonTailCallCount = 2; +} + namespace mindspore::lite { int ControlFlowScheduler::Schedule(std::vector *dst_kernels) { auto ret = this->IsolateSameInputPartials(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed."); + ret = this->IsolateOutputForCallOutputGraph(dst_kernels); + MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed"); ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed."); ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels); @@ -37,8 +43,6 @@ int ControlFlowScheduler::Schedule(std::vector *dst_kernel MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordControlFlowLinkInfo failed."); ret = this->RecordAllTailCallLinkInfo(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed"); - ret = this->IsolateOutputForCallOutputGraph(dst_kernels); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed"); ret = this->SplitNonTailCallSubGraphs(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed"); return ret; @@ -183,10 +187,26 @@ int ControlFlowScheduler::RecordNonTailCallLinkInfo() { MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty."); auto subgraph = reinterpret_cast(kernels.back()); MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr."); - MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR, - "partial inputs and corresponding call outputs size not same."); - for (size_t i = 0; i < non_tail_call_output_size; ++i) { - context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]); + if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph)) { + std::queue tail_call_q{}; + tail_call_q.push(subgraph->out_nodes().front()); + std::vector final_graphs{}; + std::set reviewed_graphs{}; + auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed."); + for (auto item : final_graphs) { + MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR, + "subgraph outputs and corresponding call outputs size not same."); + for (size_t i = 0; i < non_tail_call_output_size; ++i) { + context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]); + } + } + } else { + MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR, + "partial inputs and corresponding call outputs size not same."); + for (size_t i = 0; i < non_tail_call_output_size; ++i) { + context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]); + } } } } @@ -322,10 +342,15 @@ int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() { all_call_nodes.push_back(call_node); } - // all of the caller is tail call, continue - if (std::all_of(all_call_nodes.begin(), all_call_nodes.end(), - [](kernel::LiteKernel *call_node) { return kernel::LiteKernelUtil::IsTailCall(call_node); })) { - MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary."; + // non-tail call size less than 2, continue + int non_tail_call_size = 0; + for (auto call_node : all_call_nodes) { + if (kernel::LiteKernelUtil::IsNonTailCall(call_node)) { + non_tail_call_size++; + } + } + if (non_tail_call_size < kMinNonTailCallCount) { + MS_LOG(DEBUG) << "no need to build boundary."; continue; } for (auto partial_node : item.second) {