fix bug of control flow

This commit is contained in:
mengyuanli 2022-01-20 14:15:02 +08:00
parent fa7e23df86
commit f56501175c
1 changed files with 35 additions and 10 deletions

View File

@ -25,10 +25,16 @@
#include "src/tensorlist.h"
#include "src/common/prim_inner.h"
namespace {
const constexpr int kMinNonTailCallCount = 2;
}
namespace mindspore::lite {
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
auto ret = this->IsolateSameInputPartials(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
@ -37,8 +43,6 @@ int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernel
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordControlFlowLinkInfo failed.");
ret = this->RecordAllTailCallLinkInfo(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
ret = this->SplitNonTailCallSubGraphs(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
return ret;
@ -183,6 +187,21 @@ int ControlFlowScheduler::RecordNonTailCallLinkInfo() {
MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph)) {
std::queue<kernel::LiteKernel *> tail_call_q{};
tail_call_q.push(subgraph->out_nodes().front());
std::vector<kernel::LiteKernel *> final_graphs{};
std::set<kernel::LiteKernel *> reviewed_graphs{};
auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed.");
for (auto item : final_graphs) {
MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
"subgraph outputs and corresponding call outputs size not same.");
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]);
}
}
} else {
MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
"partial inputs and corresponding call outputs size not same.");
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
@ -190,6 +209,7 @@ int ControlFlowScheduler::RecordNonTailCallLinkInfo() {
}
}
}
}
return RET_OK;
}
@ -322,10 +342,15 @@ int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
all_call_nodes.push_back(call_node);
}
// all of the caller is tail call, continue
if (std::all_of(all_call_nodes.begin(), all_call_nodes.end(),
[](kernel::LiteKernel *call_node) { return kernel::LiteKernelUtil::IsTailCall(call_node); })) {
MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary.";
// non-tail call size less than 2, continue
int non_tail_call_size = 0;
for (auto call_node : all_call_nodes) {
if (kernel::LiteKernelUtil::IsNonTailCall(call_node)) {
non_tail_call_size++;
}
}
if (non_tail_call_size < kMinNonTailCallCount) {
MS_LOG(DEBUG) << "no need to build boundary.";
continue;
}
for (auto partial_node : item.second) {