!20196 [LITE] format

Merge pull request !20196 from yefeng/130-format-nnie
This commit is contained in:
i-robot 2021-07-21 06:25:03 +00:00 committed by Gitee
commit c0179f61d8
20 changed files with 1065 additions and 168 deletions

View File

@@ -27,20 +27,6 @@ bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const si
}
int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) {
if (perms_num == 4) {
const int nchw2nhwc[4] = {0, 2, 3, 1};
const int nhwc2nchw[4] = {0, 3, 1, 2};
const int trans3d[3] = {0, 2, 1};
if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
output->format_ = Format_NHWC;
} else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
output->format_ = Format_NCHW;
}
// though the perm is 4D by default, the input can be a 3D tensor; the op implementation should be adapted to this.
if (input->shape_size_ == 3) {
ShapeSet(perm, &perm_size, trans3d, 3);
}
}
// set output shape
size_t in_shape_size = input->shape_size_;
output->shape_size_ = in_shape_size;
@@ -76,13 +62,6 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
TensorC *output = outputs[0];
SetDataTypeFormat(output, input);
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
output->data_type_ = kNumberTypeFloat32;
}
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
const TensorC *perm_tensor = inputs[1];
const int32_t *perm_data = (int32_t *)perm_tensor->data_;
const size_t perms_num = (size_t)perm_tensor->shape_[0];
@@ -97,6 +76,28 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
for (size_t i = 0; i < perms_num; i++) {
ShapePush(perm, &perm_size, perm_data[i]);
}
if (perms_num == 4) {
const int nchw2nhwc[4] = {0, 2, 3, 1};
const int nhwc2nchw[4] = {0, 3, 1, 2};
const int trans3d[3] = {0, 2, 1};
if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
output->format_ = Format_NHWC;
} else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) &&
CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
output->format_ = Format_NCHW;
}
// though the perm is 4D by default, the input can be a 3D tensor; the op implementation should be adapted to this.
if (input->shape_size_ == 3) {
ShapeSet(perm, &perm_size, trans3d, 3);
}
}
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
output->data_type_ = kNumberTypeFloat32;
}
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
// set output shape
int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
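Note: the hunk above moves the format-propagation logic so that it runs only after the perm array has been filled from the perm tensor, and extends it to treat KHWC weights like NHWC data. A minimal standalone sketch of the matching rule, using illustrative enum names rather than the real nnacl definitions:

#include <cstddef>

// Illustrative stand-ins for the nnacl format enums.
enum SketchFormat { kFmtNCHW, kFmtNHWC, kFmtKHWC };

static bool PermEquals(const int *perm, const int *pattern, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    if (perm[i] != pattern[i]) return false;
  }
  return true;
}

// Decide the output format produced by a 4-element transpose permutation.
SketchFormat InferTransposedFormat(SketchFormat in, const int *perm, size_t n) {
  const int nchw2nhwc[4] = {0, 2, 3, 1};
  const int nhwc2nchw[4] = {0, 3, 1, 2};
  if (n == 4 && in == kFmtNCHW && PermEquals(perm, nchw2nhwc, n)) return kFmtNHWC;
  // KHWC weights share the NHWC layout, so the same perm sends them to NCHW.
  if (n == 4 && (in == kFmtNHWC || in == kFmtKHWC) && PermEquals(perm, nhwc2nchw, n)) return kFmtNCHW;
  return in;  // unrecognized perm: leave the format unchanged
}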

View File

@@ -247,6 +247,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/transpose_fusion.cc
${LITE_DIR}/tools/optimizer/graph/add_tensor_array.cc
${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -282,6 +283,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/common/node_util.cc
${LITE_DIR}/tools/common/storage.cc
${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
${LITE_DIR}/tools/converter/parser/insert_transpose.cc
${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc

View File

@@ -22,6 +22,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"
#include "tools/optimizer/common/format_utils.h"
namespace mindspore {
namespace lite {
@@ -284,12 +285,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
return RET_ERROR;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(ops::kFormat) != nullptr) {
if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) {
auto value = prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
data_info->format_ = GetValue<int64_t>(value);
}
}
if (!param_node->has_default()) {
data_info->format_ = NHWC;
}
// attr weightFormat is only used by conv-like ops' second input
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
@@ -347,6 +351,31 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
return ret;
}
int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
data_info->format_ = mindspore::NHWC;
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
auto value = input_node_prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
data_info->format_ = GetValue<int64_t>(value);
}
}
if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
std::vector<int> perm;
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
return RET_ERROR;
}
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
data_info->format_ = NCHW;
} else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
data_info->format_ = NHWC;
}
}
return RET_OK;
}
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
@@ -368,7 +397,11 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->format_ = mindspore::NHWC;
auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
if (ret != RET_OK) {
MS_LOG(ERROR) << "set format for cnode failed";
return RET_ERROR;
}
data_info->data_type_ = type_ptr->type_id();
data_info->shape_ = dims;
data_info->node_type_ = NodeType_CNode;
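Note: SetFormatForCnode above starts from the producer's kFormat attribute (defaulting to NHWC) and then corrects it when the producer is itself a Transpose. A small sketch of that correction in isolation, with plain enums standing in for the converter's types:

#include <vector>

enum SketchFormat { kNCHW, kNHWC, kKHWC };

// Given the format tagged on a transpose's input and the perm it applies,
// return the format its output should be tagged with.
SketchFormat FormatAfterTranspose(SketchFormat src, const std::vector<int> &perm) {
  const std::vector<int> nh2nc = {0, 3, 1, 2};
  const std::vector<int> nc2nh = {0, 2, 3, 1};
  if (perm == nh2nc && (src == kNHWC || src == kKHWC)) return kNCHW;
  if (perm == nc2nh && src == kNCHW) return kNHWC;
  return src;  // any other perm leaves the tagged format alone
}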

View File

@@ -36,6 +36,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
@@ -73,6 +74,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_gelu_fusion.cc
../optimizer/fusion/onnx_gelu_fusion.cc
../optimizer/fusion/squeeze_fusion.cc
../optimizer/fusion/transpose_fusion.cc
../optimizer/fisson/eliminate_concat_split.cc
../optimizer/fisson/fisson_util.cc
../optimizer/fisson/iter_node_outputs.cc

View File

@@ -67,6 +67,7 @@
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"
using std::string;
namespace mindspore::lite {
@@ -82,6 +83,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
if (!config->trainModel) {
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
@@ -338,7 +340,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
format_pass->Init(config->fmk, config->trainModel);
if (!format_pass->RunOnlyForShape(old_graph)) {
if (!format_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run format pass failed.";
return nullptr;
}

View File

@@ -23,6 +23,7 @@
#include "tools/converter/import/mindir_adjust.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"
namespace mindspore::lite {
namespace {
@@ -202,6 +203,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel);
if (!insert_transpose->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -31,6 +31,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/insert_transpose.h"
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore::lite {
@@ -103,6 +104,11 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false);
if (!insert_transpose->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -0,0 +1,511 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/insert_transpose.h"
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNCHWDimNumber = 4;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
bool IsSpecialType(const CNodePtr &cnode) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) ||
opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) ||
opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
return true;
}
return false;
}
} // namespace
void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
MS_ASSERT(cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
return;
}
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
} else if (fmk_type_ == lite::converter::FmkType_TF) {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
trans_info->pre_ = opt::kNHWC2NCHW;
trans_info->post_ = opt::kNCHW2NHWC;
}
} else {
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
return;
}
trans_info->pre_ = opt::kNCHW2NHWC;
trans_info->post_ = opt::kNHWC2NCHW;
}
}
}
AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
std::string trans_name =
before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0));
if (perm == NC2NH) {
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
} else if (perm == NH2NC) {
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
}
return new_input;
}
STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
bool before, size_t index) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
AnfNodePtr new_input = nullptr;
new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index);
if (new_input == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
}
if (new_input == cnode->input(index) || new_input == cnode) {
return lite::RET_OK;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
if (before) {
tr.SetEdge(cnode, index, new_input);
tr.Commit();
} else {
func_graph->manager()->Replace(cnode, new_input);
}
return lite::RET_OK;
}
STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
MS_LOG(ERROR) << "op don't meet nhwc condition.";
return lite::RET_ERROR;
}
std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
? specify_nhwc_op_map.at(prim->name())
: specify_nchw_op_map.at(prim->name());
if (insert_index.empty()) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
insert_index.push_back(1);
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
insert_index.push_back(i);
}
}
}
for (auto &index : insert_index) {
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
} else {
auto node_users = func_graph->manager()->node_users()[cnode];
for (auto &node_user : node_users) {
auto post_node = node_user.first;
CNodePtr tuple_get_item = nullptr;
if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
if (!train_flag_) {
MS_LOG(ERROR) << "post node is invalid.";
return lite::RET_ERROR;
} else {
tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
post_node = tuple_get_item;
func_graph->manager()->Replace(cnode, tuple_get_item);
}
}
if (func_graph->manager()->node_users()[post_node].empty()) {
continue;
}
auto post_cnode = post_node->cast<CNodePtr>();
if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
if (tuple_get_item != nullptr) {
func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
}
}
}
return lite::RET_OK;
}
STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
return lite::RET_NO_CHANGE;
}
for (size_t i = 1; i < cnode->size(); ++i) {
auto node = cnode->input(i);
if (!utils::isa<ParameterPtr>(node)) {
continue;
}
auto param_node = node->cast<ParameterPtr>();
if (param_node->has_default()) {
continue;
}
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return lite::RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return lite::RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return lite::RET_ERROR;
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (shape_vector.size() != 4) {
continue;
}
if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
shape_vector[1] == -1) {
continue;
}
std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
if (trans_cnode == nullptr) {
MS_LOG(ERROR) << "generate a transpose node failed.";
return lite::RET_ERROR;
}
auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
func_graph->manager()->Replace(param_node, trans_cnode);
}
return lite::RET_OK;
}
STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
opt::TransTypePair trans_info;
GetTransNodeFormatType(cnode, &trans_info);
if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
return lite::RET_NO_CHANGE;
}
auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH;
auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC;
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
return RET_OK;
}
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto sub_inputs = sub_graph->get_inputs();
sub_inputs_map_[sub_graph] = sub_inputs;
for (auto &node : sub_inputs) {
auto param_node = node->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);
auto node_name = node->fullname_with_scope();
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(out_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
}
} else {
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
if (status != lite::RET_OK) {
continue;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.data_.empty()) {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
} else {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size()));
}
}
}
}
void InsertTranspose::ResetSubGraphInput() {
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
auto &sub_graph = iter->first;
auto &sub_inputs = iter->second;
auto manager = sub_graph->manager();
MS_ASSERT(manager != nullptr);
for (auto &sub_input : sub_inputs) {
auto param_node = sub_graph->add_parameter();
MS_ASSERT(param_node != nullptr);
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
auto sub_param_input = sub_input->cast<ParameterPtr>();
MS_ASSERT(sub_param_input != nullptr);
sub_param_input->set_default_param(nullptr);
}
}
}
void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_input = return_node->inputs();
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
for (size_t i = 1; i < return_node->size(); ++i) {
if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
continue;
}
auto node_name = return_node->input(i)->fullname_with_scope();
if (node_name.substr(node_name.size() - 5) != "_post") {
continue;
}
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto trans_input = trans_cnode->input(1);
auto trans_input_name = trans_input->fullname_with_scope();
if (utils::isa<ParameterPtr>(trans_input)) {
trans_input->cast<ParameterPtr>()->set_name(node_name);
} else if (utils::isa<CNodePtr>(trans_input)) {
trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
}
trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
trans_cnode->set_fullname_with_scope(trans_input_name);
}
return_node->set_inputs(origin_input);
}
void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
auto return_node = sub_graph->get_return();
auto origin_inputs = return_node->inputs();
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
AbstractBasePtrList abstract_list;
bool infer_done = true;
for (size_t i = 1; i < return_node->size(); ++i) {
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
MS_ASSERT(abstract_base != nullptr);
abstract_list.emplace_back(abstract_base->Clone());
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_ASSERT(abstract_tensor != nullptr);
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
MS_ASSERT(shape_ptr != nullptr);
auto shape = shape_ptr->shape();
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
infer_done = false;
}
if (utils::isa<CNodePtr>(return_node->input(i))) {
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
}
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
infer_done = false;
}
}
}
return_node->set_inputs(origin_inputs);
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
} else {
if (abstract_list.size() != 1) {
MS_LOG(ERROR) << "cnode output is invalid.";
}
cnode->set_abstract(abstract_list.front());
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
}
bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
int status;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (main_graph) {
status = HandleGraphInput(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
status = HandleGraphNode(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
return true;
}
bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->GetAttr(opt::kInferDone) != nullptr) {
prim->EraseAttr(opt::kInferDone);
}
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
return false;
}
(void)ResetFuncGraph(sub_func_graph);
}
}
return true;
}
bool InsertTranspose::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
auto prim = GetValueNode<PrimitivePtr>(node);
if (prim == nullptr) {
continue;
}
}
// Insert transpose for ops whose format must be NHWC, which depends on the framework.
// In this process, transposes may be fused, after which the original graph may not be restorable.
if (!BasicProcess(func_graph, true)) {
MS_LOG(ERROR) << "run framework transpose unify failed.";
return false;
}
ResetSubGraphInput();
return true;
}
} // namespace lite
} // namespace mindspore
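Note: the net effect of HandleGraphNode on an op that requires NCHW inside an NHWC graph is a {0, 3, 1, 2} transpose on each selected input and a {0, 2, 3, 1} transpose on the output. A toy sketch of the shape arithmetic only (plain arrays, not the ANF IR API):

#include <array>
#include <cstdio>

// Apply a permutation to a 4-D shape: out[i] = shape[perm[i]].
std::array<int, 4> Permute(const std::array<int, 4> &shape, const std::array<int, 4> &perm) {
  std::array<int, 4> out{};
  for (int i = 0; i < 4; ++i) out[i] = shape[perm[i]];
  return out;
}

int main() {
  const std::array<int, 4> nhwc = {1, 224, 224, 3};
  auto nchw = Permute(nhwc, {0, 3, 1, 2});  // pre-transpose: 1x3x224x224
  auto back = Permute(nchw, {0, 2, 3, 1});  // post-transpose: 1x224x224x3
  printf("%d %d %d %d\n", back[0], back[1], back[2], back[3]);
  return 0;
}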

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#include <vector>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include "utils/utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/anf_exporter/fetch_content.h"
using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace lite {
class InsertTranspose {
public:
InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {}
~InsertTranspose() = default;
bool Run(const FuncGraphPtr &func_graph);
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
private:
AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm, bool before, size_t index);
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
FmkType fmk_type_{lite::converter::FmkType_MS};
bool train_flag_{false};
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_

View File

@@ -35,6 +35,7 @@
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "ops/transpose.h"
#include "tools/converter/parser/insert_transpose.h"
using mindspore::lite::converter::FmkType_ONNX;
namespace mindspore {
@@ -90,6 +91,11 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false);
if (!insert_transpose->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -33,6 +33,7 @@
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"
using mindspore::lite::converter::FmkType_TF;
namespace mindspore {
@@ -575,6 +576,11 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false);
if (!insert_transpose->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -30,6 +30,7 @@
#include "tools/converter/converter_flags.h"
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/insert_transpose.h"
using mindspore::lite::converter::FmkType_TFLITE;
namespace mindspore::lite {
@@ -104,6 +105,11 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false);
if (!insert_transpose->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@@ -26,7 +26,6 @@
namespace mindspore {
namespace opt {
constexpr auto kInferDone = "infer_done";
constexpr auto kTransDone = "trans_done";
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
struct TransTypePair {
FormatTransNodeType pre_;

View File

@@ -0,0 +1,176 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/transpose_fusion.h"
#include <unordered_map>
#include <memory>
#include <vector>
#include "tools/converter/quant_param_holder.h"
#include "mindspore/core/ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
} // namespace
bool IsBNCNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
}
return false;
}
VectorRef TransposeFusion::DefineBNPattern() const {
auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
auto bn_var = std::make_shared<CondVar>(IsBNCNode);
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
auto bn_other_var = std::make_shared<SeqVar>();
VectorRef bn_ref = VectorRef({bn_var, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var});
return bn_ref;
}
VectorRef TransposeFusion::DefineActivationscalePattern() const {
auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
auto scale_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
auto scale_var_1 = std::make_shared<CondVar>(IsParamNode);
auto scale_var_2 = std::make_shared<SeqVar>();
VectorRef scale_ref = VectorRef({scale_var, transpose_conv_ref, scale_var_1, scale_var_2});
return scale_ref;
}
VectorRef TransposeFusion::DefineActivationPattern() const {
auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
auto act_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
VectorRef act_ref = VectorRef({act_var, transpose_conv_ref});
return act_ref;
}
VectorRef TransposeFusion::DefineBiasAddPattern() const {
auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
auto bias_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
auto bias_param = std::make_shared<CondVar>(IsParamNode);
VectorRef bias_ref = VectorRef({bias_var, transpose_conv_ref, bias_param});
return bias_ref;
}
VectorRef TransposeFusion::DefineTransTransPattern() const {
auto transpose_var_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto transpose_var_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
auto transpose_param = std::make_shared<CondVar>(IsParamNode);
VectorRef trans_trans_ref = VectorRef({transpose_var_2, transpose_var_1, transpose_param});
return trans_trans_ref;
}
std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
patterns["BNPatternName"] = DefineBNPattern();
patterns["ActivationPatternName"] = DefineActivationPattern();
patterns["BiasAddPatternName"] = DefineBiasAddPattern();
patterns["ScalePatternName"] = DefineActivationscalePattern();
patterns["TransTransPatternName"] = DefineTransTransPattern();
return patterns;
}
CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
const std::string &cnode_name) {
MS_ASSERT(func_graph != nullptr && input_node != nullptr);
auto trans_prim = std::make_shared<ops::Transpose>();
MS_ASSERT(trans_prim != nullptr);
auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm});
MS_ASSERT(cnode != nullptr);
cnode->set_fullname_with_scope(cnode_name);
auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
trans_insert_prim->AddAttr("quant_params", quant_params_holder);
return cnode;
}
AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto trans_cnode_2 = node->cast<CNodePtr>();
if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) ||
!CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) {
return nullptr;
}
std::vector<int> post_perm;
if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return nullptr;
}
std::vector<int> pre_perm;
auto pre_node = trans_cnode_2->input(1);
auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) {
return nullptr;
}
if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get tanspose perm failed.";
return nullptr;
}
if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
return pre_cnode->input(1);
}
return nullptr;
}
AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
if (pattern_name == "TransTransPatternName") {
return TransTransFusion(func_graph, node);
}
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (CheckIfCNodeIsNull(node->cast<CNodePtr>()) != lite::RET_OK) {
return nullptr;
}
auto any_cnode = node->cast<CNodePtr>();
const auto transpose_node = any_cnode->input(1);
if (CheckIfCNodeIsNull(transpose_node->cast<CNodePtr>()) != lite::RET_OK) {
return nullptr;
}
const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
auto perm_node = transpose_cnode->input(2);
auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
auto tr = func_graph->manager()->Transact();
tr.SetEdge(any_cnode, 1, transpose_cnode->input(1));
tr.Commit();
return trans_post_node;
}
} // namespace mindspore::opt
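Note: TransTransFusion above removes an adjacent pair of transposes whose perms are {0, 3, 1, 2} and {0, 2, 3, 1} in either order. That is sound because the two permutations compose to the identity, which a few standalone lines can verify:

#include <array>
#include <cassert>

// perm b applied after perm a equals the single perm c with c[i] = a[b[i]].
std::array<int, 4> Compose(const std::array<int, 4> &a, const std::array<int, 4> &b) {
  std::array<int, 4> c{};
  for (int i = 0; i < 4; ++i) c[i] = a[b[i]];
  return c;
}

int main() {
  const std::array<int, 4> nh2nc = {0, 3, 1, 2};
  const std::array<int, 4> nc2nh = {0, 2, 3, 1};
  const std::array<int, 4> identity = {0, 1, 2, 3};
  assert(Compose(nh2nc, nc2nh) == identity);
  assert(Compose(nc2nh, nh2nc) == identity);
  return 0;
}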

View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
#include <string>
#include <unordered_map>
#include "tools/optimizer/graph/unify_format_pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore {
namespace opt {
class TransposeFusion : public MultiplePatternProcessPass {
public:
explicit TransposeFusion(const std::string &name = "transpose_fusion", bool multigraph = true)
: MultiplePatternProcessPass(name, multigraph) {}
~TransposeFusion() override = default;
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
VectorRef DefineBNPattern() const;
VectorRef DefineActivationPattern() const;
VectorRef DefineActivationscalePattern() const;
VectorRef DefineTransTransPattern() const;
VectorRef DefineBiasAddPattern() const;
AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_

View File

@@ -165,6 +165,8 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
}
if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
auto set_status = SetCNodeAbstract(cnode, outputs, ret);
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
if (set_status != lite::RET_OK) {
MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
FreeTensors(&inputs);
@@ -336,6 +338,7 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite:
tensor = GetCNodeTensorListVarInput(data_info);
} else {
tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
if (tensor != nullptr) {
tensor->set_format((Format)(data_info.format_));
}
}
if (tensor == nullptr) {
MS_LOG(ERROR) << "new a lite tensor failed";

View File

@@ -20,6 +20,7 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/anf_exporter/fetch_content.h"
@@ -41,9 +42,9 @@ class NodeInferShape {
bool JudgeOpSupportInfer(const CNodePtr &cnode);
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
private:
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);

View File

@@ -32,7 +32,13 @@ constexpr size_t kInputTripleNum = 3;
int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto first_input = cnode->input(1);
if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
}
auto second_input = cnode->input(2);
if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
second_input = cnode->input(2)->cast<CNodePtr>()->input(1);
}
AnfNodePtr must_monad = nullptr;
AnfNodePtr not_must_monad = nullptr;
if (utils::isa<ValueNode>(first_input)) {
@@ -72,6 +78,12 @@ int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr
pre_node = cnode->input(2);
post_node = cnode->input(1);
}
if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
pre_node = pre_node->cast<CNodePtr>()->input(1);
}
if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
post_node = post_node->cast<CNodePtr>()->input(1);
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_users = manager->node_users()[pre_node];
@@ -102,6 +114,10 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
return RET_OK;
}
manager->Replace(cnode->input(0), make_tuple_prim);
return lite::RET_OK;
}
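Note: the three hunks above apply the same unwrapping before the existing monad and dependency checks: when an input turns out to be a Transpose inserted by the new pass, look through it to the original producer. A plain-struct sketch of that idea (illustrative types, not the MindSpore IR API):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};

// If n is a Transpose, return its data input; otherwise return n itself.
std::shared_ptr<Node> LookThroughTranspose(const std::shared_ptr<Node> &n) {
  if (n && n->op == "Transpose" && !n->inputs.empty()) {
    return n->inputs.front();
  }
  return n;
}

int main() {
  auto producer = std::make_shared<Node>(Node{"Conv2D", {}});
  auto transpose = std::make_shared<Node>(Node{"Transpose", {producer}});
  assert(LookThroughTranspose(transpose) == producer);
  return 0;
}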

View File

@@ -335,6 +335,9 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
}
}
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
MS_ASSERT(manager != nullptr);
auto tr = manager->Transact();
if (before) {
@@ -428,6 +431,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
}
}
status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
@ -523,47 +527,8 @@ STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const C
MS_LOG(ERROR) << "infer shape failed."; MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
func_graph->manager()->Replace(param_node, trans_cnode);
}
return lite::RET_OK;
}
STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { func_graph->manager()->Replace(param_node, trans_cnode);
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(kTransDone) != nullptr && GetValue<bool>(prim->GetAttr(kTransDone))) {
return lite::RET_OK;
}
prim->AddAttr(kTransDone, MakeValue<bool>(true));
TransTypePair trans_info;
GetTransNodeFormatType(cnode, &trans_info);
if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) {
if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
return lite::RET_OK;
}
auto status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
return lite::RET_NO_CHANGE;
}
auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
auto status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
}
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
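Both hunks above lean on the FuncGraphManager rewiring idiom. As a minimal sketch of the pattern (old_node and new_node are hypothetical placeholders; Manage and Replace are the calls already used in this diff):

auto manager = func_graph->manager();
if (manager == nullptr) {
  manager = Manage(func_graph, true);  // build a manager on demand, as GenNewInput now does
}
// Every consumer of old_node is rewired to consume new_node instead.
manager->Replace(old_node, new_node);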
@ -763,58 +728,6 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
}
bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return false;
}
auto node_list = TopoSort(func_graph->get_return());
int status;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (main_graph && !need_reset_) {
status = HandleGraphInput(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
(void)BasicProcess(sub_func_graph, false);
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
status = HandleGraphNode(func_graph, cnode);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
return false;
}
}
return true;
}
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
@ -935,33 +848,6 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
if (prim->GetAttr(kInferDone) != nullptr) {
prim->EraseAttr(kInferDone);
}
if (prim->GetAttr(kTransDone) != nullptr) {
prim->EraseAttr(kTransDone);
}
if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) {
manager->Replace(node, cnode->input(1));
}
if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) {
auto cnode_abstract = cnode->abstract();
if (!utils::isa<abstract::AbstractTensorPtr>(cnode_abstract)) {
MS_LOG(ERROR) << "abstract is not abstract tensor.";
return false;
}
auto cnode_abstract_tensor = cnode_abstract->cast<abstract::AbstractTensorPtr>();
if (!utils::isa<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr.";
return false;
}
auto shape_ptr = utils::cast<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
auto input_abstract = GetCNodeInputAbstract(cnode, 1);
if (!utils::isa<abstract::AbstractTensorPtr>(input_abstract)) {
MS_LOG(ERROR) << "abstract is not abstract tensor.";
return false;
}
auto input_abstract_tensor = input_abstract->cast<abstract::AbstractTensorPtr>();
input_abstract_tensor->set_shape(shape_ptr);
manager->Replace(node, cnode->input(1));
}
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
@ -1025,18 +911,140 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
need_reset_ = true;
// insert transpose for some ops whose format must be NHWC, which depends on the framework.
// In this process, transpose ops cannot be fused, so that the original graph can be restored.
if (!BasicProcess(func_graph, true)) {
MS_LOG(ERROR) << "run framework transpose unify failed.";
return false;
}
ResetSubGraphInput();
// delete the inserted transpose ops and update op output shapes.
if (!ResetFuncGraph(func_graph)) {
MS_LOG(ERROR) << "reset func_graph failed.";
return false;
}
if (!RunNodeInferShape(func_graph)) {
MS_LOG(ERROR) << "RunNodeInferShape failed.";
return false;
}
ResetSubGraphInput();
ResetFuncGraph(func_graph);
return true;
}
bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!RunNodeInferShape(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!RunNodeInferShape(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
auto status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope();
return false;
}
}
return true;
}
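RunNodeInferShape, DoFixFormat, and the removed BasicProcess all repeat the same walk for control-flow nodes: input(1) and input(2) of an If/While cnode hold the two sub-FuncGraphs, so the pass recurses twice. A hedged sketch of that traversal factored into a helper (ForEachSubGraph is hypothetical and assumes the cnode carries both sub-graph inputs; GetValueNode and FuncGraphPtr are from the surrounding code):

#include <functional>

// Applies fn to both sub-graphs of an If/While cnode; returns false on a null
// sub-graph or if fn fails. Mirrors the duplicated blocks above, minus the
// SetSubGraphInput/Output and abstract bookkeeping.
bool ForEachSubGraph(const CNodePtr &cnode, const std::function<bool(const FuncGraphPtr &)> &fn) {
  for (size_t i = 1; i <= 2; ++i) {
    auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
    if (sub_func_graph == nullptr || !fn(sub_func_graph)) {
      return false;
    }
  }
  return true;
}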
bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto prim_node = cnode->input(0);
auto prim = GetValueNode<PrimitivePtr>(prim_node);
auto &nchw_op = GetNCHWOpMap();
if (!utils::isa<CNodePtr>(cnode->input(1))) {
return true;
}
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
}
if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
std::vector<int> perm;
auto status = GetTransposePerm(cnode, &perm);
if (status != RET_OK) {
return false;
}
if (!shape.empty() && shape.size() != perm.size()) {
manager->Replace(cnode, cnode->input(1));
}
}
return true;
}
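RunDoFixFormat sandwiches NCHW-only ops between a {0, 3, 1, 2} (NHWC to NCHW) pre-transpose and a {0, 2, 3, 1} (NCHW to NHWC) post-transpose. A small self-contained illustration of how such a perm acts on a shape (ApplyPerm is a hypothetical helper, not part of this commit):

#include <vector>

// Output dimension i takes input dimension perm[i], so perm {0, 3, 1, 2}
// maps an NHWC shape {1, 224, 224, 3} to the NCHW shape {1, 3, 224, 224}.
std::vector<int> ApplyPerm(const std::vector<int> &shape, const std::vector<int> &perm) {
  std::vector<int> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = shape[perm[i]];
  }
  return out;
}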
bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (IsSpecialType(cnode)) {
continue;
}
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!DoFixFormat(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
SetSubGraphInput(cnode, sub_func_graph);
if (!DoFixFormat(sub_func_graph)) {
MS_LOG(ERROR) << "subgraph infer shape failed.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return false;
}
SetSubGraphOutput(cnode, sub_func_graph);
SetSubGraphAbstract(cnode, sub_func_graph);
continue;
}
if (!RunDoFixFormat(func_graph, cnode)) {
return false;
}
}
return true;
}
@ -1049,22 +1057,23 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
if (prim == nullptr) {
continue;
}
if (prim->GetAttr(kTransDone) != nullptr) {
return true;
}
}
if (!JudgeAllOpsCanInfer(func_graph)) {
MS_LOG(ERROR) << "exist op cannot support infer shape.";
return false;
}
// insert transpose for some ops whose format must be NHWC, which depends on the framework.
// In this process, transposes can be fused, so the original graph may not be restorable.
if (!BasicProcess(func_graph, true)) {
MS_LOG(ERROR) << "run framework transpose unify failed.";
return false;
}
if (!RunNodeInferShape(func_graph)) {
MS_LOG(ERROR) << "infer shape failed.";
return false;
}
ResetSubGraphInput();
// if an op's input format can be NHWC, try transforming the op to decrease the number of transpose ops.
if (!DoFixFormat(func_graph)) {
MS_LOG(ERROR) << "DoFixFormat failed.";
return false;
}
ResetSubGraphInput();
if (!DecreaseTransposeForSingleOp(func_graph)) {
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
return false;
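Taken together, the reworked Run entry point stages the pass as JudgeAllOpsCanInfer, then RunNodeInferShape, then DoFixFormat, then DecreaseTransposeForSingleOp (the hunk is truncated here, so any later stages are not shown). A hedged usage sketch of invoking the pass; the opt namespace, the default constructor, and the call site are assumptions, not confirmed by this diff:

// Assumed invocation, modeled on the Pass::Run(func_graph) interface above.
auto unify_format_pass = std::make_shared<opt::UnifyFormatPass>();
if (!unify_format_pass->Run(func_graph)) {
  MS_LOG(ERROR) << "run UnifyFormatPass failed.";
}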

View File

@ -45,23 +45,25 @@ class UnifyFormatPass : public Pass {
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
private:
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool DoFixFormat(const FuncGraphPtr &func_graph);
bool RunNodeInferShape(const FuncGraphPtr &func_graph);
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
void ResetSubGraphInput();
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);