!29256 [MS][LITE]fix bug of control flow

Merge pull request !29256 from mengyuanli/bugfix
This commit is contained in:
i-robot 2022-01-20 01:15:44 +00:00 committed by Gitee
commit 73b744e0d6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 87 additions and 5 deletions

View File

@ -269,6 +269,10 @@ std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::
void MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
MS_ASSERT(src_tensor != dst_tensor);
if (src_tensor->data() == dst_tensor->data()) {
MS_LOG(DEBUG) << "no need to move data.";
return;
}
dst_tensor->FreeData();
dst_tensor->ResetRefCount();
dst_tensor->set_allocator(src_tensor->allocator());

View File

@ -29,6 +29,8 @@ namespace mindspore::lite {
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
auto ret = this->IsolateSameInputPartials(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "BuildBoundaryForMultipleCalledGraph failed.");
ret = this->RecordControlFlowLinkInfo();
@ -299,7 +301,8 @@ kernel::SubGraphKernel *ControlFlowScheduler::AddOutputKernel(kernel::SubGraphKe
return new_subgraph;
}
int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
// Among the subgraphs that are called more than once, check whether the call node of any of its corresponding
// partial nodes is a non-tail call.
for (auto item : more_than_once_called_partial_nodes_) {
if (item.second.size() == 1) {
MS_LOG(DEBUG) << "subgraph call only once.";
@ -325,7 +328,16 @@ int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel
MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary.";
continue;
}
for (auto partial_node : item.second) {
subgraphs_need_boundary_[subgraph].insert(partial_node);
}
}
return RET_OK;
}
int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
for (auto &item : subgraphs_need_boundary_) {
auto subgraph = item.first;
// new link tensor
auto link_tensor = new Tensor(kNumberTypeFloat32, {1});
if (link_tensor == nullptr) {
@ -379,7 +391,7 @@ int ControlFlowScheduler::IsolateOutputForCallOutputGraph(std::vector<kernel::Li
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(main_graph_kernel);
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "cast to subgraph failed.");
if (subgraph->out_nodes().size() != 1 && subgraph->out_nodes().front()->type() != schema::PrimitiveType_Call) {
if (!(subgraph->out_nodes().size() == 1 && subgraph->out_nodes().front()->type() == schema::PrimitiveType_Call)) {
MS_LOG(DEBUG) << "main graph output is not call node.";
return RET_OK;
}
@ -408,16 +420,20 @@ int ControlFlowScheduler::GetTailCallFinalSubgraphs(std::queue<kernel::LiteKerne
MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
// only get the output subgraph, the last subgraph is the output subgraph.
auto subgraphs = partial_kernel->subgraph_kernels();
for (auto subgraph : subgraphs) {
if (subgraphs.size() > 1) {
final_graphs->push_back(subgraphs.back());
return RET_OK;
} else {
auto subgraph = subgraphs.front();
auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph_kernel)) {
if (reviewed_graphs.find(subgraph) == reviewed_graphs.end()) {
tail_call_q->push(subgraph_kernel->out_nodes().front());
reviewed_graphs.insert(subgraph);
}
} else {
final_graphs->push_back(subgraph);
}
reviewed_graphs.insert(subgraph);
}
}
return GetTailCallFinalSubgraphs(tail_call_q, final_graphs, reviewed_graphs);
@ -561,6 +577,63 @@ int ControlFlowScheduler::IsolateSameInputPartials(std::vector<kernel::LiteKerne
return RET_OK;
}
// For every subgraph recorded in subgraphs_need_boundary_, isolates the inputs of each PartialFusion node found
// among the subgraph's input nodes by rebuilding the subgraph through IsolatePartialInputs. Afterwards the partial
// kernels, subgraphs_need_boundary_ and *dst_kernels are all re-pointed from the old subgraph to the rebuilt one.
// Returns RET_OK on success and RET_ERROR if any step fails.
int ControlFlowScheduler::IsolateInputOfMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
  auto status = GetSubGraphsWhichNeedBoundary();
  MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "GetSubGraphsWhichNeedBoundary failed.");

  // Maps each original subgraph to the subgraph that finally replaces it.
  std::unordered_map<kernel::SubGraphKernel *, kernel::SubGraphKernel *> replacements{};
  for (auto &entry : subgraphs_need_boundary_) {
    kernel::SubGraphKernel *origin = entry.first;

    // Collect the partial input nodes up front; the subgraph itself is replaced (and deleted) in the loop below.
    std::vector<kernel::LiteKernel *> partial_inputs{};
    for (auto in_node : origin->in_nodes()) {
      if (in_node->op_parameter()->type_ != static_cast<int>(schema::PrimitiveType_PartialFusion)) {
        continue;
      }
      partial_inputs.push_back(in_node);
    }

    kernel::SubGraphKernel *rebuilt = nullptr;
    kernel::SubGraphKernel *latest = origin;
    for (auto partial : partial_inputs) {
      rebuilt = IsolatePartialInputs(latest, partial);
      MS_CHECK_TRUE_MSG(rebuilt != nullptr, RET_ERROR, "create new subgraph failed.");
      rebuilt->set_name(latest->name());
      // The node list is emptied before the delete; presumably so that destroying the superseded subgraph does not
      // release nodes that now belong to the rebuilt one (TODO confirm against SubGraphKernel's destructor).
      latest->set_nodes({});
      delete latest;
      latest = rebuilt;
    }
    // rebuilt stays null when the subgraph had no partial input, in which case nothing needs replacing.
    if (rebuilt != nullptr) {
      replacements[origin] = rebuilt;
    }
  }

  // Pass 1: point every recorded partial node at its rebuilt subgraph and register it under the new key.
  for (const auto &swap : replacements) {
    kernel::SubGraphKernel *stale = swap.first;
    kernel::SubGraphKernel *fresh = swap.second;
    for (auto partial_node : subgraphs_need_boundary_[stale]) {
      auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
      MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
      partial_kernel->set_subgraph_kernels({fresh});
      subgraphs_need_boundary_[fresh].insert(partial_node);
    }
  }

  // Pass 2: drop the entries keyed by the old (already deleted) subgraphs.
  for (const auto &swap : replacements) {
    subgraphs_need_boundary_.erase(swap.first);
  }

  // Pass 3: substitute the rebuilt subgraphs into the scheduled kernel list.
  for (const auto &swap : replacements) {
    kernel::SubGraphKernel *stale = swap.first;
    kernel::SubGraphKernel *fresh = swap.second;
    std::replace(dst_kernels->begin(), dst_kernels->end(), stale, fresh);
  }
  return RET_OK;
}
void ControlFlowScheduler::SetSubgraphForPartialNode(
std::unordered_map<kernel::LiteKernel *, size_t> *partial_kernel_subgraph_index_map,
std::unordered_map<size_t, kernel::LiteKernel *> *subgraph_index_subgraph_kernel_map) {

View File

@ -56,6 +56,7 @@ class ControlFlowScheduler {
int IsolateSameInputPartials(std::vector<kernel::LiteKernel *> *dst_kernels);
int RecordAllTailCallLinkInfo(std::vector<kernel::LiteKernel *> *dst_kernels);
int RecordControlFlowLinkInfo();
int IsolateInputOfMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels);
private:
int SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel,
@ -77,6 +78,7 @@ class ControlFlowScheduler {
kernel::SubGraphKernel *IsolatePartialInputs(kernel::SubGraphKernel *subgraph, kernel::LiteKernel *partial);
std::set<kernel::LiteKernel *> GetSameInputPartials();
void UpdateSubGraphMap(kernel::LiteKernel *new_subgraph, kernel::LiteKernel *old_subgraph);
int GetSubGraphsWhichNeedBoundary();
private:
InnerContext *context_ = nullptr;
@ -87,6 +89,9 @@ class ControlFlowScheduler {
std::vector<kernel::LiteKernel *> non_tail_calls_{};
// key is subgraph index, value is the corresponding partial nodes.
std::unordered_map<size_t, std::set<kernel::LiteKernel *>> more_than_once_called_partial_nodes_{};
// Records the partial nodes whose corresponding subgraph needs a boundary to be built: key is the subgraph, value is
// the set of corresponding partial nodes.
std::unordered_map<kernel::SubGraphKernel *, std::set<kernel::LiteKernel *>> subgraphs_need_boundary_{};
std::unordered_map<size_t, kernel::LiteKernel *> *subgraph_index_subgraph_kernel_map_{};
std::unordered_map<kernel::LiteKernel *, size_t> *partial_kernel_subgraph_index_map_{};
};

View File

@ -247,7 +247,7 @@ STATUS DeleteRedundantTrans(std::vector<kernel::LiteKernel *> *kernels) {
auto pre_kernel_in_tensor_shape = pre_kernel->in_tensors().at(0)->shape();
auto pre_kernel_out_tensor_shape = pre_kernel->out_tensors().at(0)->shape();
for (size_t i = 0; i < pre_kernel_out_tensor_shape.size(); i++) {
if (pre_kernel_in_tensor_shape[i] == -1) {
if (pre_kernel_out_tensor_shape[i] == -1) {
MS_LOG(DEBUG) << " input need do resize.";
return RET_OK;
}