Fix control-flow bug
This commit is contained in:
parent
fa7e23df86
commit
f56501175c
|
@ -25,10 +25,16 @@
|
||||||
#include "src/tensorlist.h"
|
#include "src/tensorlist.h"
|
||||||
#include "src/common/prim_inner.h"
|
#include "src/common/prim_inner.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const constexpr int kMinNonTailCallCount = 2;
|
||||||
|
}
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||||
auto ret = this->IsolateSameInputPartials(dst_kernels);
|
auto ret = this->IsolateSameInputPartials(dst_kernels);
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
|
||||||
|
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
|
||||||
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
|
||||||
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
|
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
|
||||||
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
|
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
|
||||||
|
@ -37,8 +43,6 @@ int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernel
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordControlFlowLinkInfo failed.");
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordControlFlowLinkInfo failed.");
|
||||||
ret = this->RecordAllTailCallLinkInfo(dst_kernels);
|
ret = this->RecordAllTailCallLinkInfo(dst_kernels);
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
||||||
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
|
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
|
|
||||||
ret = this->SplitNonTailCallSubGraphs(dst_kernels);
|
ret = this->SplitNonTailCallSubGraphs(dst_kernels);
|
||||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -183,10 +187,26 @@ int ControlFlowScheduler::RecordNonTailCallLinkInfo() {
|
||||||
MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
|
MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
|
||||||
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
|
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
|
||||||
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
|
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
|
||||||
MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph)) {
|
||||||
"partial inputs and corresponding call outputs size not same.");
|
std::queue<kernel::LiteKernel *> tail_call_q{};
|
||||||
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
tail_call_q.push(subgraph->out_nodes().front());
|
||||||
context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
std::vector<kernel::LiteKernel *> final_graphs{};
|
||||||
|
std::set<kernel::LiteKernel *> reviewed_graphs{};
|
||||||
|
auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
|
||||||
|
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed.");
|
||||||
|
for (auto item : final_graphs) {
|
||||||
|
MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
||||||
|
"subgraph outputs and corresponding call outputs size not same.");
|
||||||
|
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
||||||
|
context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
||||||
|
"partial inputs and corresponding call outputs size not same.");
|
||||||
|
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
||||||
|
context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -322,10 +342,15 @@ int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
|
||||||
all_call_nodes.push_back(call_node);
|
all_call_nodes.push_back(call_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
// all of the callers are tail calls, continue
|
// fewer than two non-tail calls: no need to build a boundary, continue
|
||||||
if (std::all_of(all_call_nodes.begin(), all_call_nodes.end(),
|
int non_tail_call_size = 0;
|
||||||
[](kernel::LiteKernel *call_node) { return kernel::LiteKernelUtil::IsTailCall(call_node); })) {
|
for (auto call_node : all_call_nodes) {
|
||||||
MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary.";
|
if (kernel::LiteKernelUtil::IsNonTailCall(call_node)) {
|
||||||
|
non_tail_call_size++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (non_tail_call_size < kMinNonTailCallCount) {
|
||||||
|
MS_LOG(DEBUG) << "no need to build boundary.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (auto partial_node : item.second) {
|
for (auto partial_node : item.second) {
|
||||||
|
|
Loading…
Reference in New Issue