diff --git a/mindspore/lite/src/common/tensor_util.cc b/mindspore/lite/src/common/tensor_util.cc index 7a102a75d55..61299e11aa8 100644 --- a/mindspore/lite/src/common/tensor_util.cc +++ b/mindspore/lite/src/common/tensor_util.cc @@ -269,6 +269,10 @@ std::vector LiteTensorsToMSTensors(const std::vectordata() == dst_tensor->data()) { + MS_LOG(DEBUG) << "no need to move data."; + return; + } dst_tensor->FreeData(); dst_tensor->ResetRefCount(); dst_tensor->set_allocator(src_tensor->allocator()); diff --git a/mindspore/lite/src/control_flow/control_flow_scheduler.cc b/mindspore/lite/src/control_flow/control_flow_scheduler.cc index 5d41e4cb78d..926d11a86c9 100644 --- a/mindspore/lite/src/control_flow/control_flow_scheduler.cc +++ b/mindspore/lite/src/control_flow/control_flow_scheduler.cc @@ -29,6 +29,8 @@ namespace mindspore::lite { int ControlFlowScheduler::Schedule(std::vector *dst_kernels) { auto ret = this->IsolateSameInputPartials(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateSameInputPartials failed."); + ret = this->IsolateInputOfMultipleCalledGraph(dst_kernels); + MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "IsolateInputOfMultipleCalledGraph failed."); ret = this->BuildBoundaryForMultipleCalledGraph(dst_kernels); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "BuildBoundaryForMultipleCalledGraph failed."); ret = this->RecordControlFlowLinkInfo(); @@ -299,7 +301,8 @@ kernel::SubGraphKernel *ControlFlowScheduler::AddOutputKernel(kernel::SubGraphKe return new_subgraph; } -int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector *dst_kernels) { +int ControlFlowScheduler::GetSubGraphsWhichNeedBoundary() { + // among the more than once call subgraphs, if one of it's corresponding partial nodes' call node is non-tail call. for (auto item : more_than_once_called_partial_nodes_) { if (item.second.size() == 1) { MS_LOG(DEBUG) << "subgraph call only once."; @@ -325,7 +328,16 @@ int ControlFlowScheduler::BuildBoundaryForMultipleCalledGraph(std::vector *dst_kernels) { + for (auto &item : subgraphs_need_boundary_) { + auto subgraph = item.first; // new link tensor auto link_tensor = new Tensor(kNumberTypeFloat32, {1}); if (link_tensor == nullptr) { @@ -379,7 +391,7 @@ int ControlFlowScheduler::IsolateOutputForCallOutputGraph(std::vector(main_graph_kernel); MS_CHECK_TRUE_MSG(subgraph != nullptr, RET_ERROR, "cast to subgraph failed."); - if (subgraph->out_nodes().size() != 1 && subgraph->out_nodes().front()->type() != schema::PrimitiveType_Call) { + if (!(subgraph->out_nodes().size() == 1 && subgraph->out_nodes().front()->type() == schema::PrimitiveType_Call)) { MS_LOG(DEBUG) << "main graph output is not call node."; return RET_OK; } @@ -408,16 +420,20 @@ int ControlFlowScheduler::GetTailCallFinalSubgraphs(std::queuesubgraph_kernels(); - for (auto subgraph : subgraphs) { + if (subgraphs.size() > 1) { + final_graphs->push_back(subgraphs.back()); + return RET_OK; + } else { + auto subgraph = subgraphs.front(); auto subgraph_kernel = reinterpret_cast(subgraph); if (kernel::LiteKernelUtil::IsTailCallSubGraph(subgraph_kernel)) { if (reviewed_graphs.find(subgraph) == reviewed_graphs.end()) { tail_call_q->push(subgraph_kernel->out_nodes().front()); - reviewed_graphs.insert(subgraph); } } else { final_graphs->push_back(subgraph); } + reviewed_graphs.insert(subgraph); } } return GetTailCallFinalSubgraphs(tail_call_q, final_graphs, reviewed_graphs); @@ -561,6 +577,63 @@ int ControlFlowScheduler::IsolateSameInputPartials(std::vector *dst_kernels) { + auto ret = GetSubGraphsWhichNeedBoundary(); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "GetSubGraphsWhichNeedBoundary failed."); + std::unordered_map replace_pair{}; + + for (auto &item : subgraphs_need_boundary_) { + auto subgraph = item.first; + std::vector input_partials{}; + for (auto input : subgraph->in_nodes()) { + if (input->op_parameter()->type_ == static_cast(schema::PrimitiveType_PartialFusion)) { + input_partials.push_back(input); + } + } + kernel::SubGraphKernel *new_subgraph = nullptr; + kernel::SubGraphKernel *cur_subgraph = subgraph; + for (auto cur_partial : input_partials) { + new_subgraph = IsolatePartialInputs(cur_subgraph, cur_partial); + MS_CHECK_TRUE_MSG(new_subgraph != nullptr, RET_ERROR, "create new subgraph failed."); + new_subgraph->set_name(cur_subgraph->name()); + + cur_subgraph->set_nodes({}); + delete cur_subgraph; + cur_subgraph = new_subgraph; + } + + if (new_subgraph != nullptr) { + replace_pair[subgraph] = new_subgraph; + } + } + + // update all partial nodes' subgraph + for (auto item : replace_pair) { + auto old_subgrpah = item.first; + auto new_subgraph = item.second; + for (auto partial_node : subgraphs_need_boundary_[old_subgrpah]) { + auto partial_kernel = reinterpret_cast(partial_node->kernel()); + MS_CHECK_TRUE_MSG(partial_kernel != nullptr, RET_ERROR, "cast to partial kernel failed."); + partial_kernel->set_subgraph_kernels({new_subgraph}); + subgraphs_need_boundary_[new_subgraph].insert(partial_node); + } + } + + for (auto item : replace_pair) { + auto old_subgrpah = item.first; + subgraphs_need_boundary_.erase(old_subgrpah); + } + + // update all dst_kernels + for (auto item : replace_pair) { + auto old_subgrpah = item.first; + auto new_subgraph = item.second; + std::replace(dst_kernels->begin(), dst_kernels->end(), old_subgrpah, new_subgraph); + } + + return RET_OK; +} + void ControlFlowScheduler::SetSubgraphForPartialNode( std::unordered_map *partial_kernel_subgraph_index_map, std::unordered_map *subgraph_index_subgraph_kernel_map) { diff --git a/mindspore/lite/src/control_flow/control_flow_scheduler.h b/mindspore/lite/src/control_flow/control_flow_scheduler.h index a90c49ecb23..9ccf2d5ae05 100644 --- a/mindspore/lite/src/control_flow/control_flow_scheduler.h +++ b/mindspore/lite/src/control_flow/control_flow_scheduler.h @@ -56,6 +56,7 @@ class ControlFlowScheduler { int IsolateSameInputPartials(std::vector *dst_kernels); int RecordAllTailCallLinkInfo(std::vector *dst_kernels); int RecordControlFlowLinkInfo(); + int IsolateInputOfMultipleCalledGraph(std::vector *dst_kernels); private: int SplitSingleNonTailCallSubGraph(kernel::SubGraphKernel *subgraph_kernel, @@ -77,6 +78,7 @@ class ControlFlowScheduler { kernel::SubGraphKernel *IsolatePartialInputs(kernel::SubGraphKernel *subgraph, kernel::LiteKernel *partial); std::set GetSameInputPartials(); void UpdateSubGraphMap(kernel::LiteKernel *new_subgraph, kernel::LiteKernel *old_subgraph); + int GetSubGraphsWhichNeedBoundary(); private: InnerContext *context_ = nullptr; @@ -87,6 +89,9 @@ class ControlFlowScheduler { std::vector non_tail_calls_{}; // key is subgraph index, value is the corresponding partial nodes. std::unordered_map> more_than_once_called_partial_nodes_{}; + // record partial nodes which corresponding subgraph need build boundary, key is subgraph, value is corresponding + // partial nodes + std::unordered_map> subgraphs_need_boundary_{}; std::unordered_map *subgraph_index_subgraph_kernel_map_{}; std::unordered_map *partial_kernel_subgraph_index_map_{}; }; diff --git a/mindspore/lite/src/runtime/runtime_pass.cc b/mindspore/lite/src/runtime/runtime_pass.cc index d43465f9a61..41fb244ffa5 100644 --- a/mindspore/lite/src/runtime/runtime_pass.cc +++ b/mindspore/lite/src/runtime/runtime_pass.cc @@ -247,7 +247,7 @@ STATUS DeleteRedundantTrans(std::vector *kernels) { auto pre_kernel_in_tensor_shape = pre_kernel->in_tensors().at(0)->shape(); auto pre_kernel_out_tensor_shape = pre_kernel->out_tensors().at(0)->shape(); for (size_t i = 0; i < pre_kernel_out_tensor_shape.size(); i++) { - if (pre_kernel_in_tensor_shape[i] == -1) { + if (pre_kernel_out_tensor_shape[i] == -1) { MS_LOG(DEBUG) << " input need do resize."; return RET_OK; }