forked from mindspore-Ecosystem/mindspore
!29256 [MS][LITE]fix bug of control flow
Merge pull request !29256 from mengyuanli/bugfix
This commit is contained in:
commit
73b744e0d6
|
@ -269,6 +269,10 @@ std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::
|
|||
|
||||
void MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
|
||||
MS_ASSERT(src_tensor != dst_tensor);
|
||||
if (src_tensor->data() == dst_tensor->data()) {
|
||||
MS_LOG(DEBUG) << "no need to move data.";
|
||||
return;
|
||||
}
|
||||
dst_tensor->FreeData();
|
||||
dst_tensor->ResetRefCount();
|
||||
dst_tensor->set_allocator(src_tensor->allocator());
|
||||
|
|
|
@ -29,6 +29,8 @@ namespace mindspore::lite {
|
|||
int ControlFlowScheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||
auto ret = this->IsolateSameInputPartials(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed.");
|
||||
ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed.");
|
||||
ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "BuildBoundaryForMultipleCalledGraph failed.");
|
||||
ret = this->RecordControlFlowLinkInfo();
|
||||
|
@ -299,7 +301,8 @@ kernel::SubGraphKernel *ControlFlowScheduler::AddOutputKernel(kernel::SubGraphKe
|
|||
return new_subgraph;
|
||||
}
|
||||
|
||||
int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||
int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() {
|
||||
// among the subgraphs called more than once, if one of its corresponding partial nodes' call nodes is a non-tail call.
|
||||
for (auto item : more_than_once_called_partial_nodes_) {
|
||||
if (item.second.size() == 1) {
|
||||
MS_LOG(DEBUG) << "subgraph call only once.";
|
||||
|
@ -325,7 +328,16 @@ int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel
|
|||
MS_LOG(DEBUG) << "graph is output graph and caller is tail call, no need to build boundary.";
|
||||
continue;
|
||||
}
|
||||
for (auto partial_node : item.second) {
|
||||
subgraphs_need_boundary_[subgraph].insert(partial_node);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
|
||||
for (auto &item : subgraphs_need_boundary_) {
|
||||
auto subgraph = item.first;
|
||||
// new link tensor
|
||||
auto link_tensor = new Tensor(kNumberTypeFloat32, {1});
|
||||
if (link_tensor == nullptr) {
|
||||
|
@ -379,7 +391,7 @@ int ControlFlowScheduler::IsolateOutputForCallOutputGraph(std::vector<kernel::Li
|
|||
|
||||
auto subgraph = reinterpret_cast<kernel::SubGraphKernel *>(main_graph_kernel);
|
||||
MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "cast to subgraph failed.");
|
||||
if (subgraph->out_nodes().size() != 1 && subgraph->out_nodes().front()->type() != schema::PrimitiveType_Call) {
|
||||
if (!(subgraph->out_nodes().size() == 1 && subgraph->out_nodes().front()->type() == schema::PrimitiveType_Call)) {
|
||||
MS_LOG(DEBUG) << "main graph output is not call node.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -408,16 +420,20 @@ int ControlFlowScheduler::GetTailCallFinalSubgraphs(std::queue<kernel::LiteKerne
|
|||
MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
|
||||
// only get the output subgraph, the last subgraph is the output subgraph.
|
||||
auto subgraphs = partial_kernel->subgraph_kernels();
|
||||
for (auto subgraph : subgraphs) {
|
||||
if (subgraphs.size() > 1) {
|
||||
final_graphs->push_back(subgraphs.back());
|
||||
return RET_OK;
|
||||
} else {
|
||||
auto subgraph = subgraphs.front();
|
||||
auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(subgraph);
|
||||
if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph_kernel)) {
|
||||
if (reviewed_graphs.find(subgraph) == reviewed_graphs.end()) {
|
||||
tail_call_q->push(subgraph_kernel->out_nodes().front());
|
||||
reviewed_graphs.insert(subgraph);
|
||||
}
|
||||
} else {
|
||||
final_graphs->push_back(subgraph);
|
||||
}
|
||||
reviewed_graphs.insert(subgraph);
|
||||
}
|
||||
}
|
||||
return GetTailCallFinalSubgraphs(tail_call_q, final_graphs, reviewed_graphs);
|
||||
|
@ -561,6 +577,63 @@ int ControlFlowScheduler::IsolateSameInputPartials(std::vector<kernel::LiteKerne
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Isolates the partial-node inputs of every subgraph recorded in subgraphs_need_boundary_.
//
// GetSubGraphsWhichNeedBoundary() is called first to fill subgraphs_need_boundary_. Then, for each recorded
// subgraph, every input node whose op type is PartialFusion is handed to IsolatePartialInputs(), which yields a
// replacement subgraph; the subgraph it replaces has its node list cleared and is deleted. Finally the partial
// nodes recorded for the subgraph, the keys of subgraphs_need_boundary_ and the entries of dst_kernels are all
// re-pointed at the final replacement subgraph.
//
// dst_kernels: kernel list in which replaced subgraphs are substituted in place.
// Returns RET_OK on success; RET_ERROR if collecting the subgraphs fails, a replacement subgraph cannot be
// created, or a partial node's kernel is null.
int ControlFlowScheduler::IsolateInputOfMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels) {
  auto ret = GetSubGraphsWhichNeedBoundary();
  MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetSubGraphsWhichNeedBoundary failed.");

  // Maps each original subgraph to its final replacement. The keys are only used as addresses from here on:
  // the original subgraph is deleted inside the loop below and must never be dereferenced afterwards.
  std::unordered_map<kernel::SubGraphKernel *, kernel::SubGraphKernel *> replace_pair{};

  for (auto &item : subgraphs_need_boundary_) {
    auto subgraph = item.first;
    std::vector<kernel::LiteKernel *> input_partials{};
    for (auto input : subgraph->in_nodes()) {
      if (input->op_parameter()->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion)) {
        input_partials.push_back(input);
      }
    }

    kernel::SubGraphKernel *new_subgraph = nullptr;
    kernel::SubGraphKernel *cur_subgraph = subgraph;
    for (auto cur_partial : input_partials) {
      new_subgraph = IsolatePartialInputs(cur_subgraph, cur_partial);
      // NOTE(review): if this fails after an earlier iteration succeeded, the original subgraph is already
      // deleted but still referenced by dst_kernels and subgraphs_need_boundary_ - confirm callers abort.
      MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed.");
      new_subgraph->set_name(cur_subgraph->name());

      // Clear the node list before deleting; presumably so the old subgraph's destructor does not release
      // nodes that are now held by the replacement - TODO confirm against SubGraphKernel's destructor.
      cur_subgraph->set_nodes({});
      delete cur_subgraph;
      cur_subgraph = new_subgraph;
    }

    if (new_subgraph != nullptr) {
      replace_pair[subgraph] = new_subgraph;
    }
  }

  if (replace_pair.empty()) {
    // Nothing was replaced, so neither the boundary records nor dst_kernels need touching.
    return RET_OK;
  }

  // Update all partial nodes' subgraph and re-key the boundary records.
  // The re-keyed map is built separately and swapped in at the end: a replacement subgraph may have been
  // allocated at the address of an already-deleted subgraph, so inserting new keys into the live map and
  // erasing old keys afterwards could erase a freshly inserted entry.
  std::unordered_map<kernel::SubGraphKernel *, std::set<kernel::LiteKernel *>> rekeyed_boundary{};
  for (auto &item : subgraphs_need_boundary_) {
    auto replaced = replace_pair.find(item.first);
    if (replaced == replace_pair.end()) {
      rekeyed_boundary[item.first] = item.second;
      continue;
    }
    auto new_subgraph = replaced->second;
    for (auto partial_node : item.second) {
      auto partial_kernel = reinterpret_cast<kernel::PartialFusionKernel *>(partial_node->kernel());
      MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed.");
      partial_kernel->set_subgraph_kernels({new_subgraph});
    }
    rekeyed_boundary[new_subgraph] = item.second;
  }
  subgraphs_need_boundary_.swap(rekeyed_boundary);

  // Update all dst_kernels in a single pass. Each entry is replaced at most once, so a replacement whose
  // address happens to equal another (deleted) old subgraph's address is not replaced a second time.
  for (auto &dst_kernel : *dst_kernels) {
    for (const auto &item : replace_pair) {
      if (dst_kernel == item.first) {
        dst_kernel = item.second;
        break;
      }
    }
  }

  return RET_OK;
}
|
||||
|
||||
void ControlFlowScheduler::SetSubgraphForPartialNode(
|
||||
std::unordered_map<kernel::LiteKernel *, size_t> *partial_kernel_subgraph_index_map,
|
||||
std::unordered_map<size_t, kernel::LiteKernel *> *subgraph_index_subgraph_kernel_map) {
|
||||
|
|
|
@ -56,6 +56,7 @@ class ControlFlowScheduler {
|
|||
int IsolateSameInputPartials(std::vector<kernel::LiteKernel *> *dst_kernels);
|
||||
int RecordAllTailCallLinkInfo(std::vector<kernel::LiteKernel *> *dst_kernels);
|
||||
int RecordControlFlowLinkInfo();
|
||||
int IsolateInputOfMultipleCalledGraph(std::vector<kernel::LiteKernel *> *dst_kernels);
|
||||
|
||||
private:
|
||||
int SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel,
|
||||
|
@ -77,6 +78,7 @@ class ControlFlowScheduler {
|
|||
kernel::SubGraphKernel *IsolatePartialInputs(kernel::SubGraphKernel *subgraph, kernel::LiteKernel *partial);
|
||||
std::set<kernel::LiteKernel *> GetSameInputPartials();
|
||||
void UpdateSubGraphMap(kernel::LiteKernel *new_subgraph, kernel::LiteKernel *old_subgraph);
|
||||
int GetSubGraphsWhichNeedBoundary();
|
||||
|
||||
private:
|
||||
InnerContext *context_ = nullptr;
|
||||
|
@ -87,6 +89,9 @@ class ControlFlowScheduler {
|
|||
std::vector<kernel::LiteKernel *> non_tail_calls_{};
|
||||
// key is subgraph index, value is the corresponding partial nodes.
|
||||
std::unordered_map<size_t, std::set<kernel::LiteKernel *>> more_than_once_called_partial_nodes_{};
|
||||
// records the partial nodes whose corresponding subgraph needs a boundary built: key is the subgraph, value is the corresponding
|
||||
// partial nodes
|
||||
std::unordered_map<kernel::SubGraphKernel *, std::set<kernel::LiteKernel *>> subgraphs_need_boundary_{};
|
||||
std::unordered_map<size_t, kernel::LiteKernel *> *subgraph_index_subgraph_kernel_map_{};
|
||||
std::unordered_map<kernel::LiteKernel *, size_t> *partial_kernel_subgraph_index_map_{};
|
||||
};
|
||||
|
|
|
@ -247,7 +247,7 @@ STATUS DeleteRedundantTrans(std::vector<kernel::LiteKernel *> *kernels) {
|
|||
auto pre_kernel_in_tensor_shape = pre_kernel->in_tensors().at(0)->shape();
|
||||
auto pre_kernel_out_tensor_shape = pre_kernel->out_tensors().at(0)->shape();
|
||||
for (size_t i = 0; i < pre_kernel_out_tensor_shape.size(); i++) {
|
||||
if (pre_kernel_in_tensor_shape[i] == -1) {
|
||||
if (pre_kernel_out_tensor_shape[i] == -1) {
|
||||
MS_LOG(DEBUG) << " input need do resize.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue