diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c index ebb6b326b35..008e8d8a0e8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c @@ -27,20 +27,6 @@ bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const si } int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) { - if (perms_num == 4) { - const int nchw2nhwc[4] = {0, 2, 3, 1}; - const int nhwc2nchw[4] = {0, 3, 1, 2}; - const int trans3d[3] = {0, 2, 1}; - if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) { - output->format_ = Format_NHWC; - } else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) { - output->format_ = Format_NCHW; - } - // though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this. - if (input->shape_size_ == 3) { - ShapeSet(perm, &perm_size, trans3d, 3); - } - } // set output shape size_t in_shape_size = input->shape_size_; output->shape_size_ = in_shape_size; @@ -76,13 +62,6 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor TensorC *output = outputs[0]; SetDataTypeFormat(output, input); - if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) { - output->data_type_ = kNumberTypeFloat32; - } - if (!InferFlag(inputs, inputs_size)) { - return NNACL_INFER_INVALID; - } - const TensorC *perm_tensor = inputs[1]; const int32_t *perm_data = (int32_t *)perm_tensor->data_; const size_t perms_num = (size_t)perm_tensor->shape_[0]; @@ -97,6 +76,28 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor for (size_t i = 0; i < perms_num; i++) { ShapePush(perm, &perm_size, perm_data[i]); } + if (perms_num == 4) { + const int nchw2nhwc[4] = {0, 2, 3, 1}; + const int nhwc2nchw[4] = {0, 3, 1, 2}; + const int trans3d[3] = {0, 2, 1}; + if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) { + output->format_ = Format_NHWC; + } else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) && + CheckPermTransFormat(perm, nhwc2nchw, perms_num)) { + output->format_ = Format_NCHW; + } + // though the perm is 4-d by default, the input can be a 3-d tensor; the op implementation should adapt to this.
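+ // e.g. a transpose exported with the 4-d perm {0, 3, 1, 2} may still receive a 3-d input; the fallback
+ // below keeps dim 0 in place and swaps the two trailing dims, which is exactly trans3d = {0, 2, 1}.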
+ if (input->shape_size_ == 3) { + ShapeSet(perm, &perm_size, trans3d, 3); + } + } + if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + // set output shape int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0}; SetOutputShape(perms_num, input, output, perm, perm_size, out_shape); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index df09670192f..be73b4991fe 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -247,6 +247,7 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/transpose_fusion.cc ${LITE_DIR}/tools/optimizer/graph/add_tensor_array.cc ${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -282,6 +283,7 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/common/node_util.cc ${LITE_DIR}/tools/common/storage.cc ${LITE_DIR}/tools/converter/parser/inputs_adjust.cc + ${LITE_DIR}/tools/converter/parser/insert_transpose.cc ${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc ${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc ${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index 2dd9cfb7ff5..40d93a0c1c4 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -22,6 +22,7 @@ #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/gllo_utils.h" #include "utils/check_convert_utils.h" +#include "tools/optimizer/common/format_utils.h" namespace mindspore { namespace lite { @@ -284,12 +285,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F return RET_ERROR; } auto prim = GetValueNode(cnode->input(0)); - if (prim->GetAttr(ops::kFormat) != nullptr) { + if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) { auto value = prim->GetAttr(ops::kFormat); if (value->isa()) { data_info->format_ = GetValue(value); } } + if (!param_node->has_default()) { + data_info->format_ = NHWC; + } // attr weightFormat is only used by conv-like ops' second input if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) || @@ -347,6 +351,31 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy return ret; } +int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag, + DataInfo *data_info) { + data_info->format_ = mindspore::NHWC; + auto input_node_prim = GetValueNode((cnode->input(index)->cast()->input(0))); + if (input_node_prim->GetAttr(ops::kFormat) != nullptr) { + auto value = input_node_prim->GetAttr(ops::kFormat); + if (value->isa()) { + data_info->format_ = GetValue(value); + } + } + if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) { + std::vector perm; + if (opt::GetTransposePerm(cnode->input(index)->cast(), &perm) != RET_OK) { + return RET_ERROR; + } + if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 && + (data_info->format_ == NHWC || data_info->format_ == KHWC)) 
{ + data_info->format_ = NCHW; + } else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) { + data_info->format_ = NHWC; + } + } + return RET_OK; +} + int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag, DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); @@ -368,7 +397,11 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f } auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); std::vector dims(shape_vector.begin(), shape_vector.end()); - data_info->format_ = mindspore::NHWC; + auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info); + if (ret != RET_OK) { + MS_LOG(ERROR) << "set format for cnode failed"; + return RET_ERROR; + } data_info->data_type_ = type_ptr->type_id(); data_info->shape_ = dims; data_info->node_type_ = NodeType_CNode; diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index c84e2d8f80b..a8c6b3618ed 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -36,6 +36,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc + ${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc @@ -73,6 +74,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/tf_gelu_fusion.cc ../optimizer/fusion/onnx_gelu_fusion.cc ../optimizer/fusion/squeeze_fusion.cc + ../optimizer/fusion/transpose_fusion.cc ../optimizer/fisson/eliminate_concat_split.cc ../optimizer/fisson/fisson_util.cc ../optimizer/fisson/iter_node_outputs.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 86852c636b0..820476d9e01 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -67,6 +67,7 @@ #include "tools/optimizer/parallel/parallel_pass.h" #include "include/registry/pass_registry.h" #include "tools/optimizer/fisson/multi_conv_split_pass.h" +#include "tools/optimizer/fusion/transpose_fusion.h" using std::string; namespace mindspore::lite { @@ -82,6 +83,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: if (!config->trainModel) { // remove quantdtype when awaretraining fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); auto conv_bn_pass = std::make_shared(); @@ -338,7 +340,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con auto format_pass = std::make_shared(); format_pass->Init(config->fmk, config->trainModel); - if (!format_pass->RunOnlyForShape(old_graph)) { + if (!format_pass->Run(old_graph)) { MS_LOG(ERROR) << "Run format pass failed."; return nullptr; } diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc index 0e6ef9b5eaf..84e2af16871 100644 --- a/mindspore/lite/tools/converter/import/mindspore_importer.cc +++ 
b/mindspore/lite/tools/converter/import/mindspore_importer.cc @@ -23,6 +23,7 @@ #include "tools/converter/import/mindir_adjust.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/tensor_util.h" +#include "tools/converter/parser/insert_transpose.h" namespace mindspore::lite { namespace { @@ -202,6 +203,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + auto insert_transpose = std::make_shared(lite::converter::FmkType_MS, flag.trainModel); + if (!insert_transpose->Run(func_graph)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } if ((status = WeightFormatTransform(func_graph)) != RET_OK) { MS_LOG(ERROR) << "WeightFormatTransform failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 3c948bdc8cc..51628db9cd5 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -31,6 +31,7 @@ #include "tools/converter/quant_param_holder.h" #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/gllo_utils.h" +#include "tools/converter/parser/insert_transpose.h" using mindspore::lite::converter::FmkType_CAFFE; namespace mindspore::lite { @@ -103,6 +104,11 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + auto insert_transpose = std::make_shared(lite::converter::FmkType_CAFFE, false); + if (!insert_transpose->Run(res_graph_)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } if ((status = WeightFormatTransform(res_graph_)) != RET_OK) { MS_LOG(ERROR) << "WeightFormatTransform failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/insert_transpose.cc b/mindspore/lite/tools/converter/parser/insert_transpose.cc new file mode 100644 index 00000000000..5c881836c78 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/insert_transpose.cc @@ -0,0 +1,511 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/parser/insert_transpose.h" +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "src/common/common.h" +#include "src/common/utils.h" +#include "tools/common/tensor_util.h" + +using mindspore::lite::NCHW_SHAPE; +namespace mindspore { +namespace lite { +namespace { +constexpr size_t kNCHWDimNumber = 4; +const std::vector NH2NC = {0, 3, 1, 2}; +const std::vector NC2NH = {0, 2, 3, 1}; +bool IsSpecialType(const CNodePtr &cnode) { + if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) || + opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) || + opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) { + return true; + } + return false; +} +} // namespace + +void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) { + MS_ASSERT(cnode != nullptr); + auto prim_node = cnode->input(0); + auto prim = GetValueNode(prim_node); + MS_ASSERT(prim != nullptr); + auto &specify_nhwc_op_map = opt::GetNHWCOpMap(); + auto &specify_nchw_op_map = opt::GetNCHWOpMap(); + if (fmk_type_ == lite::converter::FmkType_TFLITE) { + if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { + return; + } + trans_info->pre_ = opt::kNHWC2NCHW; + trans_info->post_ = opt::kNCHW2NHWC; + } else if (fmk_type_ == lite::converter::FmkType_TF) { + if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) { + trans_info->pre_ = opt::kNCHW2NHWC; + trans_info->post_ = opt::kNHWC2NCHW; + } + if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) { + trans_info->pre_ = opt::kNHWC2NCHW; + trans_info->post_ = opt::kNCHW2NHWC; + } + } else { + if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) { + if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr && + GetValue(prim->GetAttr(ops::kFormat)) == NHWC) { + return; + } + trans_info->pre_ = opt::kNCHW2NHWC; + trans_info->post_ = opt::kNHWC2NCHW; + } + } +} + +AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &perm, bool before, size_t index) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + AnfNodePtr new_input = nullptr; + AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode; + std::string trans_name = + before ? 
cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post"; + new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name); + auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0)); + if (perm == NC2NH) { + new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW)); + } else if (perm == NH2NC) { + new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); + } + return new_input; +} + +STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, + bool before, size_t index) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + AnfNodePtr new_input = nullptr; + + new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index); + if (new_input == nullptr) { + MS_LOG(ERROR) << "generate a transpose node failed."; + return lite::RET_ERROR; + } + if (new_input == cnode->input(index) || new_input == cnode) { + return lite::RET_OK; + } + auto manager = func_graph->manager(); + if (manager == nullptr) { + manager = Manage(func_graph, true); + } + MS_ASSERT(manager != nullptr); + auto tr = manager->Transact(); + if (before) { + tr.SetEdge(cnode, index, new_input); + tr.Commit(); + } else { + func_graph->manager()->Replace(cnode, new_input); + } + return lite::RET_OK; +} + +STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector<int> &perm) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + auto prim_node = cnode->input(0); + auto prim = GetValueNode<PrimitivePtr>(prim_node); + MS_ASSERT(prim != nullptr); + auto &specify_nhwc_op_map = opt::GetNHWCOpMap(); + auto &specify_nchw_op_map = opt::GetNCHWOpMap(); + if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() && + specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) { + MS_LOG(ERROR) << "op doesn't meet nhwc condition."; + return lite::RET_ERROR; + } + std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end() + ?
specify_nhwc_op_map.at(prim->name()) + : specify_nchw_op_map.at(prim->name()); + if (insert_index.empty()) { + if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr && + GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) { + insert_index.push_back(1); + } else { + for (size_t i = 1; i < cnode->size(); ++i) { + insert_index.push_back(i); + } + } + } + for (auto &index : insert_index) { + if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) { + MS_LOG(ERROR) << "generate a new input failed."; + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector<int> &perm) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + if (!cnode->abstract()->isa<abstract::AbstractTuple>()) { + if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) { + MS_LOG(ERROR) << "generate a new input failed."; + return lite::RET_ERROR; + } + } else { + auto node_users = func_graph->manager()->node_users()[cnode]; + for (auto &node_user : node_users) { + auto post_node = node_user.first; + CNodePtr tuple_get_item = nullptr; + if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) { + if (!train_flag_) { + MS_LOG(ERROR) << "post node is invalid."; + return lite::RET_ERROR; + } else { + tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0); + post_node = tuple_get_item; + func_graph->manager()->Replace(cnode, tuple_get_item); + } + } + if (func_graph->manager()->node_users()[post_node].empty()) { + continue; + } + auto post_cnode = post_node->cast<CNodePtr>(); + if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) { + MS_LOG(ERROR) << "generate a new input failed."; + return lite::RET_ERROR; + } + if (tuple_get_item != nullptr) { + func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1)); + } + } + } + return lite::RET_OK; +} + +STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) { + return lite::RET_NO_CHANGE; + } + for (size_t i = 1; i < cnode->size(); ++i) { + auto node = cnode->input(i); + if (!utils::isa<ParameterPtr>(node)) { + continue; + } + auto param_node = node->cast<ParameterPtr>(); + if (param_node->has_default()) { + continue; + } + auto abstract_base = param_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); + return lite::RET_ERROR; + } + if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); + return lite::RET_ERROR; + } + auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); + if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); + return lite::RET_ERROR; + } + auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); + if (shape_vector.size() != 4) { + continue; + } + if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 && + shape_vector[1] == -1) { + continue; + } + std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H], + shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]}; +
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims)); + auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre"); + if (trans_cnode == nullptr) { + MS_LOG(ERROR) << "generate a transpose node failed."; + return lite::RET_ERROR; + } + auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0)); + new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC)); + func_graph->manager()->Replace(param_node, trans_cnode); + } + return lite::RET_OK; +} + +STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_ASSERT(func_graph != nullptr && cnode != nullptr); + opt::TransTypePair trans_info; + GetTransNodeFormatType(cnode, &trans_info); + if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) { + return lite::RET_NO_CHANGE; + } + auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH; + auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC; + if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { + MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); + return lite::RET_ERROR; + } + if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) { + return RET_OK; + } + if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { + MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); + return lite::RET_ERROR; + } + return lite::RET_OK; +} + +void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { + MS_ASSERT(cnode != nullptr && sub_graph != nullptr); + auto sub_inputs = sub_graph->get_inputs(); + sub_inputs_map_[sub_graph] = sub_inputs; + for (auto &node : sub_inputs) { + auto param_node = node->cast<ParameterPtr>(); + MS_ASSERT(param_node != nullptr); + auto node_name = node->fullname_with_scope(); + auto last_underline = node_name.find_last_of("_"); + node_name = node_name.substr(0, last_underline); + last_underline = node_name.find_last_of("_"); + auto index = std::stoi(node_name.substr(last_underline + 1)) + 3; + param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone()); + if (utils::isa<CNodePtr>(cnode->input(index))) { + ShapeVector shape_vec = {-1}; + auto out_cnode = cnode->input(index)->cast<CNodePtr>(); + MS_ASSERT(out_cnode != nullptr); + auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0)); + if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) { + param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec)); + } + } else { + lite::DataInfo data_info; + if (utils::isa<ParameterPtr>(cnode->input(index))) { + if (cnode->input(index)->cast<ParameterPtr>()->has_default()) { + param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param()); + } + continue; + } + auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info); + if (status != lite::RET_OK) { + continue; + } + ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); + if (data_info.data_.empty()) { + param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec)); + } else { + param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec, + data_info.data_.data(), data_info.data_.size())); + } + } + } +} + +void InsertTranspose::ResetSubGraphInput() { + for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { + auto &sub_graph = iter->first; + auto &sub_inputs = iter->second; + auto
manager = sub_graph->manager(); + MS_ASSERT(manager != nullptr); + for (auto &sub_input : sub_inputs) { + auto param_node = sub_graph->add_parameter(); + MS_ASSERT(param_node != nullptr); + param_node->set_abstract(sub_input->abstract()->Clone()); + param_node->set_name(sub_input->fullname_with_scope()); + manager->Replace(sub_input, param_node); + auto sub_param_input = sub_input->cast(); + MS_ASSERT(sub_param_input != nullptr); + sub_param_input->set_default_param(nullptr); + } + } +} + +void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { + MS_ASSERT(cnode != nullptr && sub_graph != nullptr); + auto return_node = sub_graph->get_return(); + auto origin_input = return_node->inputs(); + lite::RemoveIfDepend(return_node); + lite::RemoveIfMakeTuple(return_node); + for (size_t i = 1; i < return_node->size(); ++i) { + if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) { + continue; + } + auto node_name = return_node->input(i)->fullname_with_scope(); + if (node_name.substr(node_name.size() - 5) != "_post") { + continue; + } + auto trans_cnode = return_node->input(i)->cast(); + MS_ASSERT(trans_cnode != nullptr); + auto trans_input = trans_cnode->input(1); + auto trans_input_name = trans_input->fullname_with_scope(); + if (utils::isa(trans_input)) { + trans_input->cast()->set_name(node_name); + } else if (utils::isa(trans_input)) { + trans_input->cast()->set_fullname_with_scope(node_name); + } + trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode"; + trans_cnode->set_fullname_with_scope(trans_input_name); + } + return_node->set_inputs(origin_input); +} + +void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { + MS_ASSERT(cnode != nullptr && sub_graph != nullptr); + auto return_node = sub_graph->get_return(); + auto origin_inputs = return_node->inputs(); + lite::RemoveIfDepend(return_node); + lite::RemoveIfMakeTuple(return_node); + AbstractBasePtrList abstract_list; + bool infer_done = true; + for (size_t i = 1; i < return_node->size(); ++i) { + auto abstract_base = opt::GetCNodeInputAbstract(return_node, i); + MS_ASSERT(abstract_base != nullptr); + abstract_list.emplace_back(abstract_base->Clone()); + auto abstract_tensor = abstract_base->cast(); + MS_ASSERT(abstract_tensor != nullptr); + auto shape_ptr = utils::cast(abstract_tensor->BuildShape()); + MS_ASSERT(shape_ptr != nullptr); + auto shape = shape_ptr->shape(); + if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { + infer_done = false; + } + if (utils::isa(return_node->input(i))) { + auto input_cnode = return_node->input(i)->cast(); + if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { + input_cnode = input_cnode->input(1)->cast(); + } + auto input_prim = GetValueNode(input_cnode->input(0)); + if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue(input_prim->GetAttr(opt::kInferDone))) { + infer_done = false; + } + } + } + return_node->set_inputs(origin_inputs); + if (utils::isa(cnode->abstract())) { + cnode->set_abstract(std::make_shared(abstract_list)); + } else { + if (abstract_list.size() != 1) { + MS_LOG(ERROR) << "cnode output is invalid."; + } + cnode->set_abstract(abstract_list.front()); + } + auto prim = GetValueNode(cnode->input(0)); + prim->AddAttr(opt::kInferDone, MakeValue(infer_done)); +} + +bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) { + MS_ASSERT(func_graph != nullptr); + auto graph_name = 
GetValue<std::string>(func_graph->get_attr("graph_name")); + auto manager = Manage(func_graph, true); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr."; + return false; + } + auto node_list = TopoSort(func_graph->get_return()); + int status; + for (auto &node : node_list) { + if (!utils::isa<CNodePtr>(node)) { + continue; + } + auto cnode = node->cast<CNodePtr>(); + if (IsSpecialType(cnode)) { + continue; + } + if (main_graph) { + status = HandleGraphInput(func_graph, cnode); + if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { + return false; + } + } + if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) { + auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); + if (sub_func_graph == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return false; + } + SetSubGraphInput(cnode, sub_func_graph); + (void)BasicProcess(sub_func_graph, false); + SetSubGraphOutput(cnode, sub_func_graph); + sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); + if (sub_func_graph == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return false; + } + SetSubGraphInput(cnode, sub_func_graph); + (void)BasicProcess(sub_func_graph, false); + SetSubGraphOutput(cnode, sub_func_graph); + SetSubGraphAbstract(cnode, sub_func_graph); + continue; + } + status = HandleGraphNode(func_graph, cnode); + if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { + return false; + } + } + return true; +} + +bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto manager = Manage(func_graph, true); + if (manager == nullptr) { + MS_LOG(ERROR) << "manager is nullptr."; + return false; + } + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa<CNodePtr>(node)) { + continue; + } + auto cnode = node->cast<CNodePtr>(); + auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); + if (prim->GetAttr(opt::kInferDone) != nullptr) { + prim->EraseAttr(opt::kInferDone); + } + if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) { + auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); + if (sub_func_graph == nullptr) { + return false; + } + (void)ResetFuncGraph(sub_func_graph); + sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); + if (sub_func_graph == nullptr) { + return false; + } + (void)ResetFuncGraph(sub_func_graph); + } + } + return true; +} + +bool InsertTranspose::Run(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + auto prim = GetValueNode<PrimitivePtr>(node); + if (prim == nullptr) { + continue; + } + } + // insert transpose for some ops whose format must be NHWC, which depends on the framework. + // In this process, transpose ops can be fused, so the original graph may not be restorable.
+ if (!BasicProcess(func_graph, true)) { + MS_LOG(ERROR) << "run framework transpose unify failed."; + return false; + } + ResetSubGraphInput(); + return true; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/insert_transpose.h b/mindspore/lite/tools/converter/parser/insert_transpose.h new file mode 100644 index 00000000000..10039a52926 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/insert_transpose.h @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ + +#include +#include +#include +#include +#include +#include "utils/utils.h" +#include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/format_utils.h" +#include "tools/anf_exporter/fetch_content.h" + +using mindspore::lite::converter::FmkType; +namespace mindspore { +namespace lite { +class InsertTranspose { + public: + InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {} + ~InsertTranspose() = default; + bool Run(const FuncGraphPtr &func_graph); + STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm); + STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm); + STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector perm, bool before, + size_t index = 0); + + private: + AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &perm, bool before, size_t index); + bool ResetFuncGraph(const FuncGraphPtr &func_graph); + bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); + void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info); + STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + void ResetSubGraphInput(); + void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + FmkType fmk_type_{lite::converter::FmkType_MS}; + bool train_flag_{false}; + std::unordered_map> sub_inputs_map_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_ diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 0a2c4694a1f..085f0ab01ca 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -35,6 +35,7 @@ #include "tools/converter/parser/onnx/onnx_pad_adjust.h" 
#include "tools/converter/parser/parser_utils.h" #include "ops/transpose.h" +#include "tools/converter/parser/insert_transpose.h" using mindspore::lite::converter::FmkType_ONNX; namespace mindspore { @@ -90,6 +91,11 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + auto insert_transpose = std::make_shared(lite::converter::FmkType_ONNX, false); + if (!insert_transpose->Run(res_graph_)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) { MS_LOG(ERROR) << "WeightFormatTransform failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 46e1e238cc6..366aa2ab1f2 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -33,6 +33,7 @@ #include "tools/converter/parser/tf/functionalize_control_op_pass.h" #include "tools/converter/parser/parser_utils.h" #include "tools/common/tensor_util.h" +#include "tools/converter/parser/insert_transpose.h" using mindspore::lite::converter::FmkType_TF; namespace mindspore { @@ -575,6 +576,11 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + auto insert_transpose = std::make_shared(lite::converter::FmkType_TF, false); + if (!insert_transpose->Run(res_graph_)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } if ((status = WeightFormatTransform(res_graph_)) != RET_OK) { MS_LOG(ERROR) << "WeightFormatTransform failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 5e1d6c59a5f..05b7af9da16 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -30,6 +30,7 @@ #include "tools/converter/converter_flags.h" #include "tools/converter/parser/tflite/tflite_inputs_adjust.h" #include "tools/converter/parser/parser_utils.h" +#include "tools/converter/parser/insert_transpose.h" using mindspore::lite::converter::FmkType_TFLITE; namespace mindspore::lite { @@ -104,6 +105,11 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } + auto insert_transpose = std::make_shared(lite::converter::FmkType_TFLITE, false); + if (!insert_transpose->Run(res_graph_)) { + MS_LOG(ERROR) << "Run insert transpose failed."; + return nullptr; + } if ((status = WeightFormatTransform(res_graph_)) != RET_OK) { MS_LOG(ERROR) << "WeightFormatTransform failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); diff --git a/mindspore/lite/tools/optimizer/common/format_utils.h b/mindspore/lite/tools/optimizer/common/format_utils.h index 98e2b3efa03..ebca3016f71 100644 --- a/mindspore/lite/tools/optimizer/common/format_utils.h +++ b/mindspore/lite/tools/optimizer/common/format_utils.h @@ -26,7 +26,6 @@ namespace mindspore { namespace opt { constexpr auto kInferDone = "infer_done"; -constexpr auto kTransDone = "trans_done"; enum FormatTransNodeType { 
kNCHW2NHWC, kNHWC2NCHW, kNONE }; struct TransTypePair { FormatTransNodeType pre_; diff --git a/mindspore/lite/tools/optimizer/fusion/transpose_fusion.cc b/mindspore/lite/tools/optimizer/fusion/transpose_fusion.cc new file mode 100644 index 00000000000..60594c8b393 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/transpose_fusion.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/transpose_fusion.h" +#include +#include +#include +#include "tools/converter/quant_param_holder.h" +#include "mindspore/core/ops/transpose.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore::opt { +namespace { +const std::vector<int> NH2NC = {0, 3, 1, 2}; +const std::vector<int> NC2NH = {0, 2, 3, 1}; +} // namespace +bool IsBNCNode(const BaseRef &n) { + if (utils::isa<AnfNodePtr>(n)) { + auto anf_node = utils::cast<AnfNodePtr>(n); + return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) || + CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm); + } + return false; +} + +VectorRef TransposeFusion::DefineBNPattern() const { + auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto conv_var = std::make_shared<CondVar>(IsConvNode); + auto transpose_param = std::make_shared<CondVar>(IsParamNode); + VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param}); + auto bn_var = std::make_shared<CondVar>(IsBNCNode); + auto bn_mean_var = std::make_shared<CondVar>(IsParamNode); + auto bn_variable_var = std::make_shared<CondVar>(IsParamNode); + auto bn_other_var = std::make_shared<SeqVar>(); + VectorRef bn_ref = VectorRef({bn_var, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var}); + return bn_ref; +} + +VectorRef TransposeFusion::DefineActivationscalePattern() const { + auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto conv_var = std::make_shared<CondVar>(IsConvNode); + auto transpose_param = std::make_shared<CondVar>(IsParamNode); + VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param}); + auto scale_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>); + auto scale_var_1 = std::make_shared<CondVar>(IsParamNode); + auto scale_var_2 = std::make_shared<SeqVar>(); + VectorRef scale_ref = VectorRef({scale_var, transpose_conv_ref, scale_var_1, scale_var_2}); + return scale_ref; +} + +VectorRef TransposeFusion::DefineActivationPattern() const { + auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto conv_var = std::make_shared<CondVar>(IsConvNode); + auto transpose_param = std::make_shared<CondVar>(IsParamNode); + VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param}); + auto act_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>); + VectorRef act_ref = VectorRef({act_var, transpose_conv_ref}); + return act_ref; +} + +VectorRef TransposeFusion::DefineBiasAddPattern() const { + auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto conv_var = std::make_shared<CondVar>(IsConvNode); + auto transpose_param = std::make_shared<CondVar>(IsParamNode); + VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param}); + auto bias_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>); + auto bias_param = std::make_shared<CondVar>(IsParamNode); + VectorRef act_ref = VectorRef({bias_var, transpose_conv_ref, bias_param}); + return act_ref; +} + +VectorRef TransposeFusion::DefineTransTransPattern() const { + auto transpose_var_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto transpose_var_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>); + auto transpose_param = std::make_shared<CondVar>(IsParamNode); + VectorRef trans_trans_ref = VectorRef({transpose_var_2, transpose_var_1, transpose_param}); + return trans_trans_ref; +} + +std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const { + std::unordered_map<std::string, VectorRef> patterns; + patterns["BNPatternName"] = DefineBNPattern(); + patterns["ActivationPatternName"] = DefineActivationPattern(); + patterns["BiasAddPatternName"] = DefineBiasAddPattern(); + patterns["ScalePatternName"] = DefineActivationscalePattern(); + patterns["TransTransPatternName"] = DefineTransTransPattern(); + return patterns; +} + +CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm, + const std::string &cnode_name) { + MS_ASSERT(func_graph != nullptr && input_node != nullptr); + auto trans_prim = std::make_shared<ops::Transpose>(); + MS_ASSERT(trans_prim != nullptr); + auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm}); + MS_ASSERT(cnode != nullptr); + cnode->set_fullname_with_scope(cnode_name); + auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1); + auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0)); + trans_insert_prim->AddAttr("quant_params", quant_params_holder); + return cnode; +} + +AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph, + const mindspore::AnfNodePtr &node) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(node != nullptr); + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + auto trans_cnode_2 = node->cast<CNodePtr>(); + if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) || + !CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) { + return nullptr; + } + std::vector<int> post_perm; + if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) { + MS_LOG(ERROR) << "get transpose perm failed."; + return nullptr; + } + std::vector<int> pre_perm; + auto pre_node = trans_cnode_2->input(1); + auto pre_cnode = pre_node->cast<CNodePtr>(); + if (pre_cnode == nullptr) { + return nullptr; + } + if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) { + MS_LOG(ERROR) << "get transpose perm failed."; + return nullptr; + } + if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) { + return pre_cnode->input(1); + } + return nullptr; +} + +AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph, + const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const { + if (pattern_name == "TransTransPatternName") { + return TransTransFusion(func_graph, node); + } + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); +
return nullptr; + } + if (CheckIfCNodeIsNull(node->cast()) != lite::RET_OK) { + return nullptr; + } + auto any_cnode = node->cast(); + const auto transpose_node = any_cnode->input(1); + if (CheckIfCNodeIsNull(transpose_node->cast()) != lite::RET_OK) { + return nullptr; + } + const CNodePtr &transpose_cnode = transpose_node->cast(); + auto perm_node = transpose_cnode->input(2); + auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post"); + auto tr = func_graph->manager()->Transact(); + tr.SetEdge(any_cnode, 1, transpose_cnode->input(1)); + tr.Commit(); + return trans_post_node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/transpose_fusion.h b/mindspore/lite/tools/optimizer/fusion/transpose_fusion.h new file mode 100644 index 00000000000..e28382b74c2 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/transpose_fusion.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_ + +#include +#include +#include "tools/optimizer/graph/unify_format_pass.h" +#include "backend/optimizer/common/optimizer.h" +#include "schema/inner/model_generated.h" +#include "tools/optimizer/common/multiple_pattern_process_pass.h" + +namespace mindspore { +namespace opt { +class TransposeFusion : public MultiplePatternProcessPass { + public: + explicit TransposeFusion(const std::string &name = "transpose_fusion", bool multigraph = true) + : MultiplePatternProcessPass(name, multigraph) {} + + ~TransposeFusion() override = default; + + std::unordered_map DefinePatterns() const override; + VectorRef DefineBNPattern() const; + VectorRef DefineActivationPattern() const; + VectorRef DefineActivationscalePattern() const; + VectorRef DefineTransTransPattern() const; + VectorRef DefineBiasAddPattern() const; + AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const; + AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &, + const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.cc b/mindspore/lite/tools/optimizer/graph/node_infershape.cc index 1daf78d7502..4c3adc83e06 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.cc @@ -165,6 +165,8 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) { } if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) { auto set_status = SetCNodeAbstract(cnode, outputs, ret); + auto cnode_prim = GetValueNode(cnode->input(0)); + cnode_prim->AddAttr(ops::kFormat, MakeValue(inputs[0]->format())); if (set_status != lite::RET_OK) { MS_LOG(ERROR) << "set CNode abstract 
failed: " << cnode->fullname_with_scope(); FreeTensors(&inputs); @@ -336,6 +338,7 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vectorset_format((Format)(data_info.format_)); } if (tensor == nullptr) { MS_LOG(ERROR) << "new a lite tensor failed"; diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.h b/mindspore/lite/tools/optimizer/graph/node_infershape.h index f65cab79e18..9809c8fd587 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.h +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "schema/inner/model_generated.h" #include "src/tensor.h" #include "tools/anf_exporter/fetch_content.h" @@ -41,9 +42,9 @@ class NodeInferShape { bool JudgeOpSupportInfer(const CNodePtr &cnode); std::vector GetInputShape(const CNodePtr &cnode, size_t index); std::vector GetIntVecInput(const CNodePtr &cnode, size_t index); + STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector *inputs); private: - STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector *inputs); STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector *const_ms_inputs); STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector *var_ms_inputs); lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info); diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index d368bebe851..f0d4a56b618 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -32,7 +32,13 @@ constexpr size_t kInputTripleNum = 3; int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_ASSERT(func_graph != nullptr && cnode != nullptr); auto first_input = cnode->input(1); + if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) { + first_input = cnode->input(1)->cast()->input(1); + } auto second_input = cnode->input(2); + if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) { + second_input = cnode->input(2)->cast()->input(1); + } AnfNodePtr must_monad = nullptr; AnfNodePtr not_must_monad = nullptr; if (utils::isa(first_input)) { @@ -72,6 +78,12 @@ int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr pre_node = cnode->input(2); post_node = cnode->input(1); } + if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) { + pre_node = cnode->input(1)->cast()->input(1); + } + if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) { + post_node = cnode->input(2)->cast()->input(1); + } auto manager = func_graph->manager(); MS_ASSERT(manager != nullptr); auto node_users = manager->node_users()[pre_node]; @@ -102,6 +114,10 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c auto make_tuple_prim = NewValueNode(std::make_shared()); auto manager = func_graph->manager(); MS_ASSERT(manager != nullptr); + if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) { + manager->Replace(cnode->input(0)->cast()->input(0), make_tuple_prim); + return RET_OK; + } manager->Replace(cnode->input(0), make_tuple_prim); return lite::RET_OK; } diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc index 7b323b83ee2..23c780ec53a 100644 --- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc @@ -335,6 +335,9 @@ STATUS 
UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP } } auto manager = func_graph->manager(); + if (manager == nullptr) { + manager = Manage(func_graph, true); + } MS_ASSERT(manager != nullptr); auto tr = manager->Transact(); if (before) { @@ -428,6 +431,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const } } status = node_infer_shape_.InferShape(cnode); + if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { MS_LOG(ERROR) << "infer shape failed."; return lite::RET_ERROR; @@ -523,47 +527,8 @@ STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const C MS_LOG(ERROR) << "infer shape failed."; return lite::RET_ERROR; } - func_graph->manager()->Replace(param_node, trans_cnode); - } - return lite::RET_OK; -} -STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_ASSERT(func_graph != nullptr && cnode != nullptr); - auto prim_node = cnode->input(0); - auto prim = GetValueNode(prim_node); - MS_ASSERT(prim != nullptr); - if (prim->GetAttr(kTransDone) != nullptr && GetValue(prim->GetAttr(kTransDone))) { - return lite::RET_OK; - } - prim->AddAttr(kTransDone, MakeValue(true)); - TransTypePair trans_info; - GetTransNodeFormatType(cnode, &trans_info); - if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) { - if (!need_reset_ && TransTransFusion(func_graph, cnode)) { - return lite::RET_OK; - } - auto status = node_infer_shape_.InferShape(cnode); - if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { - MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope(); - return lite::RET_ERROR; - } - return lite::RET_NO_CHANGE; - } - auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH; - auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC; - if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { - MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); - return lite::RET_ERROR; - } - auto status = node_infer_shape_.InferShape(cnode); - if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { - MS_LOG(ERROR) << "infer shape failed."; - return lite::RET_ERROR; - } - if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { - MS_LOG(ERROR) << "insert post node failed." 
<< cnode->fullname_with_scope(); - return lite::RET_ERROR; + func_graph->manager()->Replace(param_node, trans_cnode); } return lite::RET_OK; } @@ -763,58 +728,6 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph prim->AddAttr(kInferDone, MakeValue(infer_done)); } -bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) { - MS_ASSERT(func_graph != nullptr); - auto graph_name = GetValue(func_graph->get_attr("graph_name")); - auto manager = Manage(func_graph, true); - if (manager == nullptr) { - MS_LOG(ERROR) << "manager is nullptr."; - return false; - } - auto node_list = TopoSort(func_graph->get_return()); - int status; - for (auto &node : node_list) { - if (!utils::isa(node)) { - continue; - } - auto cnode = node->cast(); - if (IsSpecialType(cnode)) { - continue; - } - if (main_graph && !need_reset_) { - status = HandleGraphInput(func_graph, cnode); - if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { - return false; - } - } - if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { - auto sub_func_graph = GetValueNode(cnode->input(1)); - if (sub_func_graph == nullptr) { - lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); - return false; - } - SetSubGraphInput(cnode, sub_func_graph); - (void)BasicProcess(sub_func_graph, false); - SetSubGraphOutput(cnode, sub_func_graph); - sub_func_graph = GetValueNode(cnode->input(2)); - if (sub_func_graph == nullptr) { - lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); - return false; - } - SetSubGraphInput(cnode, sub_func_graph); - (void)BasicProcess(sub_func_graph, false); - SetSubGraphOutput(cnode, sub_func_graph); - SetSubGraphAbstract(cnode, sub_func_graph); - continue; - } - status = HandleGraphNode(func_graph, cnode); - if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { - return false; - } - } - return true; -} - bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); auto graph_name = GetValue(func_graph->get_attr("graph_name")); @@ -935,33 +848,6 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { if (prim->GetAttr(kInferDone) != nullptr) { prim->EraseAttr(kInferDone); } - if (prim->GetAttr(kTransDone) != nullptr) { - prim->EraseAttr(kTransDone); - } - if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) { - manager->Replace(node, cnode->input(1)); - } - if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) { - auto cnode_abstract = cnode->abstract(); - if (!utils::isa(cnode_abstract)) { - MS_LOG(ERROR) << "abstract is not abstract tensor."; - return false; - } - auto cnode_abstract_tensor = cnode_abstract->cast(); - if (!utils::isa(cnode_abstract_tensor->BuildShape())) { - MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr."; - return false; - } - auto shape_ptr = utils::cast(cnode_abstract_tensor->BuildShape()); - auto input_abstract = GetCNodeInputAbstract(cnode, 1); - if (!utils::isa(input_abstract)) { - MS_LOG(ERROR) << "abstract is not abstract tensor."; - return false; - } - auto input_abstract_tensor = input_abstract->cast(); - input_abstract_tensor->set_shape(shape_ptr); - manager->Replace(node, cnode->input(1)); - } if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { auto sub_func_graph = GetValueNode(cnode->input(1)); if (sub_func_graph == nullptr) { @@ -1025,18 +911,140 @@ bool 
@@ -1025,18 +911,140 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "exist op cannot support infer shape.";
     return false;
   }
-  need_reset_ = true;
-  // insert transpose for some ops whose format must be NHWC, which is depend on framework.
-  // In this process, transpose op cannot be fused to restore the original graph.
-  if (!BasicProcess(func_graph, true)) {
-    MS_LOG(ERROR) << "run framework transpose unify failed.";
+  if (!RunNodeInferShape(func_graph)) {
+    MS_LOG(ERROR) << "RunNodeInferShape failed.";
     return false;
   }
   ResetSubGraphInput();
-  // delete insert transpose op and update op output shape.
-  if (!ResetFuncGraph(func_graph)) {
-    MS_LOG(ERROR) << "reset func_graph failed.";
-    return false;
+  ResetFuncGraph(func_graph);
+  return true;
+}
+
+bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (IsSpecialType(cnode)) {
+      continue;
+    }
+    if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
+      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      if (sub_func_graph == nullptr) {
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+        return false;
+      }
+      SetSubGraphInput(cnode, sub_func_graph);
+      if (!RunNodeInferShape(sub_func_graph)) {
+        MS_LOG(ERROR) << "subgraph infer shape failed.";
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
+        return false;
+      }
+      SetSubGraphOutput(cnode, sub_func_graph);
+
+      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
+      if (sub_func_graph == nullptr) {
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+        return false;
+      }
+      SetSubGraphInput(cnode, sub_func_graph);
+      if (!RunNodeInferShape(sub_func_graph)) {
+        MS_LOG(ERROR) << "subgraph infer shape failed.";
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
+        return false;
+      }
+      SetSubGraphOutput(cnode, sub_func_graph);
+      SetSubGraphAbstract(cnode, sub_func_graph);
+      continue;
+    }
+    auto status = node_infer_shape_.InferShape(cnode);
+    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
+      MS_LOG(ERROR) << "infer shape failed."
+                    << cnode->fullname_with_scope();
+      return false;
+    }
+  }
+  return true;
+}
+
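
RunNodeInferShape visits nodes in topological order, recurses into both branch subgraphs of If/While, and treats RET_INFER_INVALID as non-fatal so partially known shapes do not abort conversion. A simplified, self-contained sketch of that control flow; every type here is a stand-in, not the MindSpore API:

// Traversal skeleton: topological visit, recursion into control-flow
// subgraphs, and only hard errors (not "shape unknown yet") abort the walk.
#include <cstdio>
#include <vector>

enum Status { kOk, kInferInvalid, kError };

struct Node {
  bool is_control_flow = false;
  std::vector<struct Graph *> subgraphs;  // e.g. then/else branch graphs
};

struct Graph {
  std::vector<Node> nodes;  // assumed already topologically sorted
};

Status InferNode(const Node &) { return kOk; }  // placeholder for per-op inference

bool RunInfer(Graph *graph) {
  for (auto &node : graph->nodes) {
    if (node.is_control_flow) {
      for (auto *sub : node.subgraphs) {
        if (!RunInfer(sub)) return false;  // recurse and propagate failure
      }
      continue;
    }
    Status s = InferNode(node);
    if (s != kOk && s != kInferInvalid) return false;  // only hard errors abort
  }
  return true;
}

int main() {
  Graph g;
  g.nodes.resize(3);
  printf("%s\n", RunInfer(&g) ? "ok" : "failed");
  return 0;
}
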
+bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
+  auto prim_node = cnode->input(0);
+  auto prim = GetValueNode<PrimitivePtr>(prim_node);
+  auto &nchw_op = GetNCHWOpMap();
+  if (!utils::isa<CNodePtr>(cnode->input(1))) {
+    return true;
+  }
+  if (utils::isa<CNodePtr>(cnode->input(1))) {
+    auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
+    if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
+      InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
+      InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
+    }
+  }
+  if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
+    auto manager = func_graph->manager();
+    if (manager == nullptr) {
+      manager = Manage(func_graph, true);
+    }
+    auto shape = node_infer_shape_.GetInputShape(cnode, 1);
+    std::vector<int> perm;
+    auto status = GetTransposePerm(cnode, &perm);
+    if (status != RET_OK) {
+      return false;
+    }
+    if (!shape.empty() && shape.size() != perm.size()) {
+      manager->Replace(cnode, cnode->input(1));
+    }
+  }
+  return true;
+}
+
+bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
+  auto node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (IsSpecialType(cnode)) {
+      continue;
+    }
+    if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
+      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      if (sub_func_graph == nullptr) {
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+        return false;
+      }
+      SetSubGraphInput(cnode, sub_func_graph);
+      if (!DoFixFormat(sub_func_graph)) {
+        MS_LOG(ERROR) << "subgraph fix format failed.";
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
+        return false;
+      }
+      SetSubGraphOutput(cnode, sub_func_graph);
+
+      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
+      if (sub_func_graph == nullptr) {
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+        return false;
+      }
+      SetSubGraphInput(cnode, sub_func_graph);
+      if (!DoFixFormat(sub_func_graph)) {
+        MS_LOG(ERROR) << "subgraph fix format failed.";
+        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
+        return false;
+      }
+      SetSubGraphOutput(cnode, sub_func_graph);
+      SetSubGraphAbstract(cnode, sub_func_graph);
+      continue;
+    }
+    if (!RunDoFixFormat(func_graph, cnode)) {
+      return false;
+    }
+  }
+  return true;
+}
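
The Transpose branch of RunDoFixFormat drops a transpose whose perm length cannot match the input rank, for example a leftover 4-D perm on a 3-D tensor, by replacing the node with its own input. A standalone sketch of just that condition; IsDroppableTranspose is a hypothetical name:

// Mirrors the shape/perm rank check above: an empty shape means "unknown,
// keep the node"; a rank mismatch means the transpose is meaningless.
#include <cstdio>
#include <vector>

bool IsDroppableTranspose(const std::vector<int> &input_shape, const std::vector<int> &perm) {
  return !input_shape.empty() && input_shape.size() != perm.size();
}

int main() {
  // 3-D input with a 4-D perm: droppable, so the node would be bypassed.
  printf("%d\n", IsDroppableTranspose({4, 16, 16}, {0, 2, 3, 1}));    // 1
  // Ranks agree: the transpose is kept.
  printf("%d\n", IsDroppableTranspose({1, 4, 16, 16}, {0, 2, 3, 1})); // 0
  return 0;
}
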
@@ -1049,22 +1057,23 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
     if (prim == nullptr) {
       continue;
     }
-    if (prim->GetAttr(kTransDone) != nullptr) {
-      return true;
-    }
   }
   if (!JudgeAllOpsCanInfer(func_graph)) {
     MS_LOG(ERROR) << "exist op cannot support infer shape.";
     return false;
   }
-  // insert transpose for some ops whose format must be NHWC, which is depend on framework.
-  // In this process, tranpose can be fused, which the original graph may not be able to restored.
-  if (!BasicProcess(func_graph, true)) {
-    MS_LOG(ERROR) << "run framework transpose unify failed.";
+  if (!RunNodeInferShape(func_graph)) {
+    MS_LOG(ERROR) << "infer shape failed.";
     return false;
   }
   ResetSubGraphInput();
-  // if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
+
+  if (!DoFixFormat(func_graph)) {
+    MS_LOG(ERROR) << "DoFixFormat failed.";
+    return false;
+  }
+  ResetSubGraphInput();
+
   if (!DecreaseTransposeForSingleOp(func_graph)) {
     MS_LOG(ERROR) << "run local trans insert optimizer failed.";
     return false;
diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h
index ec0632e3c99..c9b3df53ec6 100644
--- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h
@@ -45,23 +45,25 @@ class UnifyFormatPass : public Pass {
   bool RunOnlyForShape(const FuncGraphPtr &func_graph);
 
  private:
+  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
+  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
+  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
+                     size_t index = 0);
+  bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  bool DoFixFormat(const FuncGraphPtr &func_graph);
+  bool RunNodeInferShape(const FuncGraphPtr &func_graph);
   bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
   bool ResetFuncGraph(const FuncGraphPtr &func_graph);
-  bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
   bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
   bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
   bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
-  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
-                     size_t index = 0);
+
   void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
   STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
-  STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                               std::set<CNodePtr> *visit_transposes);
-  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
   STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
-  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
   void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
   void ResetSubGraphInput();
   void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
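
The reordered Run() makes the stage ordering explicit: shapes are inferred before formats are fixed, and formats are fixed before transpose reduction runs, since each later stage relies on results settled by the earlier one. A hedged sketch of that pipeline shape, with invented Stage/RunPipeline scaffolding; only the ordering is taken from this patch:

// Each stage runs in order and the first failure aborts, mirroring how Run()
// chains RunNodeInferShape, DoFixFormat, and DecreaseTransposeForSingleOp.
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

struct Stage {
  std::string name;
  std::function<bool()> run;
};

bool RunPipeline(const std::vector<Stage> &stages) {
  for (const auto &stage : stages) {
    if (!stage.run()) {
      fprintf(stderr, "%s failed.\n", stage.name.c_str());
      return false;  // abort on the first failing stage, as Run() does
    }
  }
  return true;
}

int main() {
  std::vector<Stage> stages = {
      {"RunNodeInferShape", [] { return true; }},
      {"DoFixFormat", [] { return true; }},
      {"DecreaseTransposeForSingleOp", [] { return true; }},
  };
  return RunPipeline(stages) ? 0 : 1;
}
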