From c9c94b7340d96c51ef54210ae11a48b194f82ba3 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Fri, 29 Oct 2021 09:34:23 +0800 Subject: [PATCH] move call inputs to partial op inputs --- mindspore/lite/src/lite_session.cc | 2 +- mindspore/lite/src/runtime/infer_manager.cc | 2 +- mindspore/lite/tools/common/node_util.cc | 10 ++ mindspore/lite/tools/common/node_util.h | 2 + .../import/mindir_control_flow_adjust.cc | 107 ++++++++++++++++-- .../import/mindir_control_flow_adjust.h | 1 + 6 files changed, 113 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index f4c05645bb8..ab0cda3cec7 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -422,7 +422,7 @@ int LiteSession::IsolateOutputTensor() { Tensor *new_tensor = new Tensor(src_tensor->data_type(), src_tensor->shape(), src_tensor->format(), Tensor::GRAPH_OUTPUT); if (new_tensor == nullptr) { - MS_LOG(ERROR) << "duplicate new outptu failed."; + MS_LOG(ERROR) << "duplicate new output failed."; return RET_NULL_PTR; } new_tensor->set_allocator(src_tensor->allocator()); /* GPU use opencl allocator */ diff --git a/mindspore/lite/src/runtime/infer_manager.cc b/mindspore/lite/src/runtime/infer_manager.cc index 3fd0f10a727..1c0ec935953 100644 --- a/mindspore/lite/src/runtime/infer_manager.cc +++ b/mindspore/lite/src/runtime/infer_manager.cc @@ -122,7 +122,7 @@ int KernelInferShape(const std::vector &inputs, const std::vecto std::vector in_tensors; std::vector out_tensors; if (parameter->type_ == schema::PrimitiveType_PartialFusion || parameter->type_ == schema::PrimitiveType_Switch || - parameter->type_ == schema::PrimitiveType_Call) { + parameter->type_ == schema::PrimitiveType_Call || parameter->type_ == schema::PrimitiveType_SwitchLayer) { MS_LOG(INFO) << "no need infer shape."; return RET_OK; } diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index c42af780b7e..1de8a5a8079 
100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -392,6 +392,16 @@ bool IsSwitch(const AnfNodePtr &node) { return opt::CheckPrimitiveType(node, prim::kPrimSwitch); } +bool IsSwitchLayer(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + if (!utils::isa(node)) { + return false; + } + return opt::CheckPrimitiveType(node, prim::kPrimSwitchLayer); +} + bool IsMakeTuple(const AnfNodePtr &node) { if (node == nullptr) { return false; diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index da79f118145..f68ccadbf83 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -419,6 +419,8 @@ bool IsCall(const AnfNodePtr &node); bool IsSwitch(const AnfNodePtr &node); +bool IsSwitchLayer(const AnfNodePtr &node); + bool IsMakeTuple(const AnfNodePtr &node); ValueNodePtr GetPartialFusionPrim(); diff --git a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.cc b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.cc index ac465c0d56d..9880ce6b891 100644 --- a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.cc +++ b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.cc @@ -26,13 +26,16 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "nnacl/op_base.h" +namespace { +constexpr const int kSwitchTruePartialIndex = 2; +constexpr const int kSwitchFalsePartialIndex = 3; +constexpr const int kSwitchInputSize = 4; +constexpr const int kSwitchLayerInputSize = 3; +constexpr const int kSwitchLayerMakeTupleIndex = 2; +} // namespace namespace mindspore { namespace lite { -constexpr const int kSwitchTruePartialIndex = 2; -constexpr const int kSwitchFalsePartialIndex = 3; -constexpr const int kPartialFgVnodeIndex = 1; - bool MindIRControlFlowAdjust::HasCallAfter(const FuncGraphPtr &partial_fg) { 
MS_CHECK_TRUE_MSG(partial_fg != nullptr, false, "partial_fg is nullptr."); auto output_node = partial_fg->output(); @@ -229,6 +232,88 @@ int MindIRControlFlowAdjust::AddAfterFgForInlinedFg(const std::set return RET_OK; } +int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set &all_func_graphs) { + for (auto &graph : all_func_graphs) { + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!IsCall(node)) { + continue; + } + auto call_cnode = node->cast(); + MS_ASSERT(call_cnode != nullptr); + auto call_cnode_inputs = call_cnode->inputs(); + if (call_cnode_inputs.size() == 1) { + MS_LOG(DEBUG) << "no need move call inputs."; + continue; + } + auto call_first_input = call_cnode->input(0); + if (!utils::isa(call_first_input)) { + // This situation will be handled in the InsertPartialFusionForRawCall function + continue; + } + auto call_first_input_cnode = call_first_input->cast(); + MS_ASSERT(call_first_input_cnode != nullptr); + if (IsPartialFusion(call_first_input_cnode)) { + auto partial_cnode_inputs = call_first_input_cnode->inputs(); + std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs)); + call_first_input_cnode->set_inputs(partial_cnode_inputs); + } + + if (IsSwitch(call_first_input_cnode)) { + auto switch_cnode_inputs = call_first_input_cnode->inputs(); + if (switch_cnode_inputs.size() != kSwitchInputSize) { + MS_LOG(ERROR) << "switch op inputs size not right."; + return RET_ERROR; + } + if (!IsPartialFusion(switch_cnode_inputs[kSwitchTruePartialIndex]) || + !IsPartialFusion(switch_cnode_inputs[kSwitchFalsePartialIndex])) { + MS_LOG(ERROR) << "switch inputs are not partial ops, not support now."; + return RET_NOT_SUPPORT; + } + + auto true_partial_cnode = switch_cnode_inputs.at(kSwitchTruePartialIndex)->cast(); + auto true_partial_cnode_inputs = true_partial_cnode->inputs(); + std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), +
std::back_inserter(true_partial_cnode_inputs)); + true_partial_cnode->set_inputs(true_partial_cnode_inputs); + + auto false_partial_cnode = switch_cnode_inputs.at(kSwitchFalsePartialIndex)->cast(); + auto false_partial_cnode_inputs = false_partial_cnode->inputs(); + std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), + std::back_inserter(false_partial_cnode_inputs)); + false_partial_cnode->set_inputs(false_partial_cnode_inputs); + } + + if (IsSwitchLayer(call_first_input_cnode)) { + auto switch_layer_cnode_inputs = call_first_input_cnode->inputs(); + if (switch_layer_cnode_inputs.size() != kSwitchLayerInputSize) { + MS_LOG(ERROR) << "switch layer op inputs size not right."; + return RET_ERROR; + } + if (!IsMakeTuple(switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex])) { + MS_LOG(ERROR) << "SwitchLayer op last input is not a MakeTuple op, not support now."; + return RET_NOT_SUPPORT; + } + auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast(); + auto make_tuple_op_inputs = make_tuple_op->inputs(); + for (size_t i = 1; i < make_tuple_op_inputs.size(); i++) { + if (!IsPartialFusion(make_tuple_op_inputs[i])) { + MS_LOG(ERROR) << "switch layer op make tuple inputs are not partial fusion ops, not support now."; + return RET_NOT_SUPPORT; + } + auto partial_node = make_tuple_op_inputs[i]->cast(); + auto partial_node_inputs = partial_node->inputs(); + std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs)); + partial_node->set_inputs(partial_node_inputs); + } + } + + call_cnode->set_inputs({call_first_input_cnode}); + } + } + return RET_OK; +} + int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::set &all_func_graphs) { for (auto &graph : all_func_graphs) { auto node_list = TopoSort(graph->get_return()); @@ -239,14 +324,13 @@ int MindIRControlFlowAdjust::InsertPartialFusionForRawCall(const std::setcast(); MS_ASSERT(call_node != nullptr); auto call_cnode_inputs = 
call_cnode->inputs(); - auto cnode_first_input = call_cnode->input(0); - if (!utils::isa(cnode_first_input)) { + auto call_first_input = call_cnode->input(0); + if (!utils::isa(call_first_input)) { continue; } - if (GetValueNode(cnode_first_input->cast()) == nullptr) { + if (GetValueNode(call_first_input->cast()) == nullptr) { continue; } - std::vector partial_cnode_inputs = {lite::GetPartialFusionPrim()}; std::copy(call_cnode_inputs.begin(), call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs)); auto partial_cnode = graph->NewCNode(partial_cnode_inputs); @@ -285,7 +369,12 @@ bool MindIRControlFlowAdjust::Run(const FuncGraphPtr &func_graph) { MS_LOG(INFO) << "Not is control flow model."; return true; } - int ret = InsertPartialFusionForRawCall(all_func_graphs); + int ret = MoveCallInputsToPartialFusionInputs(all_func_graphs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MoveCallInputsToPartialFusionInputs failed."; + return false; + } + ret = InsertPartialFusionForRawCall(all_func_graphs); if (ret != RET_OK) { MS_LOG(ERROR) << "InsertPartialFusionForRawCall failed."; return false; diff --git a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h index 20403e821ba..ae589d51767 100644 --- a/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h +++ b/mindspore/lite/tools/converter/import/mindir_control_flow_adjust.h @@ -43,6 +43,7 @@ class MindIRControlFlowAdjust { int InsertPartialFusionForRawCall(const std::set &all_func_graphs); CNodePtr GetMainFgSwitchNode(const FuncGraphPtr &fg); int ResetFuncGraph(const FuncGraphPtr &fg, std::set all_func_graphs); + int MoveCallInputsToPartialFusionInputs(const std::set &all_func_graphs); private: FmkType fmk_type_ = FmkType::kFmkTypeMs;