!29329 [MS][LITE]fix bug of control model
Merge pull request !29329 from mengyuanli/bugfix
This commit is contained in:
commit
74bca04d43
|
@ -25,10 +25,16 @@
|
|||
#include "src/tensorlist.h"
|
||||
#include "src/common/prim_inner.h"
|
||||
|
||||
namespace {
|
||||
const constexpr int kMinNonTailCallCount = 2;
|
||||
}
|
||||
|
||||
namespace mindspore::lite {
|
||||
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||
auto ret = this->IsolateSameInputPartials(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
|
||||
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
|
||||
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
|
||||
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
|
||||
|
@ -37,8 +43,6 @@ int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernel
|
|||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "RecordControlFlowLinkInfo failed.");
|
||||
ret = this->RecordAllTailCallLinkInfo(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
||||
ret = this->IsolateOutputForCallOutputGraph(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateOutputForCallOutputGraph failed");
|
||||
ret = this->SplitNonTailCallSubGraphs(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "SplitNonTailCallSubGraphs failed");
|
||||
return ret;
|
||||
|
@ -183,10 +187,26 @@ int ControlFlowScheduler::RecordNonTailCallLinkInfo() {
|
|||
MS_CHECK_TRUE_MSG(!kernels.empty(), RET_ERROR, "partial subgraph kernels empty.");
|
||||
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(kernels.back());
|
||||
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "partial node's subgraph kernel is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
||||
"partial inputs and corresponding call outputs size not same.");
|
||||
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
||||
context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
||||
if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph)) {
|
||||
std::queue<kernel::LiteKernel *> tail_call_q{};
|
||||
tail_call_q.push(subgraph->out_nodes().front());
|
||||
std::vector<kernel::LiteKernel *> final_graphs{};
|
||||
std::set<kernel::LiteKernel *> reviewed_graphs{};
|
||||
auto ret = GetTailCallFinalSubgraphs(&tail_call_q, &final_graphs, reviewed_graphs);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetTailCallFinalSubgraphs failed.");
|
||||
for (auto item : final_graphs) {
|
||||
MS_CHECK_TRUE_MSG(item->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
||||
"subgraph outputs and corresponding call outputs size not same.");
|
||||
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
||||
context_->SetLinkInfo(item->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_CHECK_TRUE_MSG(subgraph->out_tensors().size() == non_tail_call_output_size, RET_ERROR,
|
||||
"partial inputs and corresponding call outputs size not same.");
|
||||
for (size_t i = 0; i < non_tail_call_output_size; ++i) {
|
||||
context_->SetLinkInfo(subgraph->out_tensors()[i], non_tail_call->out_tensors()[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -322,10 +342,15 @@ int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
|
|||
all_call_nodes.push_back(call_node);
|
||||
}
|
||||
|
||||
// all of the caller is tail call, continue
|
||||
if (std::all_of(all_call_nodes.begin(), all_call_nodes.end(),
|
||||
[](kernel::LiteKernel *call_node) { return kernel::LiteKernelUtil::IsTailCall(call_node); })) {
|
||||
MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary.";
|
||||
// non-tail call size less than 2, continue
|
||||
int non_tail_call_size = 0;
|
||||
for (auto call_node : all_call_nodes) {
|
||||
if (kernel::LiteKernelUtil::IsNonTailCall(call_node)) {
|
||||
non_tail_call_size++;
|
||||
}
|
||||
}
|
||||
if (non_tail_call_size < kMinNonTailCallCount) {
|
||||
MS_LOG(DEBUG) << "no need to build boundary.";
|
||||
continue;
|
||||
}
|
||||
for (auto partial_node : item.second) {
|
||||
|
|
Loading…
Reference in New Issue