!20196 [LITE] format
Merge pull request !20196 from yefeng/130-format-nnie
This commit is contained in:
commit
c0179f61d8
|
@ -27,20 +27,6 @@ bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const si
|
||||||
}
|
}
|
||||||
|
|
||||||
int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) {
|
int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) {
|
||||||
if (perms_num == 4) {
|
|
||||||
const int nchw2nhwc[4] = {0, 2, 3, 1};
|
|
||||||
const int nhwc2nchw[4] = {0, 3, 1, 2};
|
|
||||||
const int trans3d[3] = {0, 2, 1};
|
|
||||||
if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
|
|
||||||
output->format_ = Format_NHWC;
|
|
||||||
} else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
|
|
||||||
output->format_ = Format_NCHW;
|
|
||||||
}
|
|
||||||
// though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this.
|
|
||||||
if (input->shape_size_ == 3) {
|
|
||||||
ShapeSet(perm, &perm_size, trans3d, 3);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// set output shape
|
// set output shape
|
||||||
size_t in_shape_size = input->shape_size_;
|
size_t in_shape_size = input->shape_size_;
|
||||||
output->shape_size_ = in_shape_size;
|
output->shape_size_ = in_shape_size;
|
||||||
|
@ -76,13 +62,6 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
||||||
TensorC *output = outputs[0];
|
TensorC *output = outputs[0];
|
||||||
|
|
||||||
SetDataTypeFormat(output, input);
|
SetDataTypeFormat(output, input);
|
||||||
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
|
|
||||||
output->data_type_ = kNumberTypeFloat32;
|
|
||||||
}
|
|
||||||
if (!InferFlag(inputs, inputs_size)) {
|
|
||||||
return NNACL_INFER_INVALID;
|
|
||||||
}
|
|
||||||
|
|
||||||
const TensorC *perm_tensor = inputs[1];
|
const TensorC *perm_tensor = inputs[1];
|
||||||
const int32_t *perm_data = (int32_t *)perm_tensor->data_;
|
const int32_t *perm_data = (int32_t *)perm_tensor->data_;
|
||||||
const size_t perms_num = (size_t)perm_tensor->shape_[0];
|
const size_t perms_num = (size_t)perm_tensor->shape_[0];
|
||||||
|
@ -97,6 +76,28 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
||||||
for (size_t i = 0; i < perms_num; i++) {
|
for (size_t i = 0; i < perms_num; i++) {
|
||||||
ShapePush(perm, &perm_size, perm_data[i]);
|
ShapePush(perm, &perm_size, perm_data[i]);
|
||||||
}
|
}
|
||||||
|
if (perms_num == 4) {
|
||||||
|
const int nchw2nhwc[4] = {0, 2, 3, 1};
|
||||||
|
const int nhwc2nchw[4] = {0, 3, 1, 2};
|
||||||
|
const int trans3d[3] = {0, 2, 1};
|
||||||
|
if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
|
||||||
|
output->format_ = Format_NHWC;
|
||||||
|
} else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) &&
|
||||||
|
CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
|
||||||
|
output->format_ = Format_NCHW;
|
||||||
|
}
|
||||||
|
// though the perm is 4d in default, the input can be a 3d tensor. The op implementation should be adapted to this.
|
||||||
|
if (input->shape_size_ == 3) {
|
||||||
|
ShapeSet(perm, &perm_size, trans3d, 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
|
||||||
|
output->data_type_ = kNumberTypeFloat32;
|
||||||
|
}
|
||||||
|
if (!InferFlag(inputs, inputs_size)) {
|
||||||
|
return NNACL_INFER_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
// set output shape
|
// set output shape
|
||||||
int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
|
int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
|
||||||
SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
|
SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
|
||||||
|
|
|
@ -247,6 +247,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
||||||
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
|
||||||
|
${LITE_DIR}/tools/optimizer/fusion/transpose_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/graph/add_tensor_array.cc
|
${LITE_DIR}/tools/optimizer/graph/add_tensor_array.cc
|
||||||
${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
|
${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
|
||||||
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
||||||
|
@ -282,6 +283,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
||||||
${LITE_DIR}/tools/common/node_util.cc
|
${LITE_DIR}/tools/common/node_util.cc
|
||||||
${LITE_DIR}/tools/common/storage.cc
|
${LITE_DIR}/tools/common/storage.cc
|
||||||
${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
|
${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
|
||||||
|
${LITE_DIR}/tools/converter/parser/insert_transpose.cc
|
||||||
${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
|
${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
|
||||||
${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
|
${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
|
||||||
${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
|
${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "tools/optimizer/common/format_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -284,12 +285,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) {
|
||||||
auto value = prim->GetAttr(ops::kFormat);
|
auto value = prim->GetAttr(ops::kFormat);
|
||||||
if (value->isa<mindspore::Int64Imm>()) {
|
if (value->isa<mindspore::Int64Imm>()) {
|
||||||
data_info->format_ = GetValue<int64_t>(value);
|
data_info->format_ = GetValue<int64_t>(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (!param_node->has_default()) {
|
||||||
|
data_info->format_ = NHWC;
|
||||||
|
}
|
||||||
// attr weightFormat is only used by conv-like ops' second input
|
// attr weightFormat is only used by conv-like ops' second input
|
||||||
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
|
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
|
||||||
|
@ -347,6 +351,31 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||||
|
DataInfo *data_info) {
|
||||||
|
data_info->format_ = mindspore::NHWC;
|
||||||
|
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
|
||||||
|
if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
|
||||||
|
auto value = input_node_prim->GetAttr(ops::kFormat);
|
||||||
|
if (value->isa<mindspore::Int64Imm>()) {
|
||||||
|
data_info->format_ = GetValue<int64_t>(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
|
||||||
|
std::vector<int> perm;
|
||||||
|
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
|
||||||
|
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
|
||||||
|
data_info->format_ = NCHW;
|
||||||
|
} else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
|
||||||
|
data_info->format_ = NHWC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||||
DataInfo *data_info) {
|
DataInfo *data_info) {
|
||||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||||
|
@ -368,7 +397,11 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
|
||||||
}
|
}
|
||||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||||
data_info->format_ = mindspore::NHWC;
|
auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "set format for cnode failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
data_info->data_type_ = type_ptr->type_id();
|
data_info->data_type_ = type_ptr->type_id();
|
||||||
data_info->shape_ = dims;
|
data_info->shape_ = dims;
|
||||||
data_info->node_type_ = NodeType_CNode;
|
data_info->node_type_ = NodeType_CNode;
|
||||||
|
|
|
@ -36,6 +36,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
|
||||||
|
@ -73,6 +74,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
../optimizer/fusion/tf_gelu_fusion.cc
|
../optimizer/fusion/tf_gelu_fusion.cc
|
||||||
../optimizer/fusion/onnx_gelu_fusion.cc
|
../optimizer/fusion/onnx_gelu_fusion.cc
|
||||||
../optimizer/fusion/squeeze_fusion.cc
|
../optimizer/fusion/squeeze_fusion.cc
|
||||||
|
../optimizer/fusion/transpose_fusion.cc
|
||||||
../optimizer/fisson/eliminate_concat_split.cc
|
../optimizer/fisson/eliminate_concat_split.cc
|
||||||
../optimizer/fisson/fisson_util.cc
|
../optimizer/fisson/fisson_util.cc
|
||||||
../optimizer/fisson/iter_node_outputs.cc
|
../optimizer/fisson/iter_node_outputs.cc
|
||||||
|
|
|
@ -67,6 +67,7 @@
|
||||||
#include "tools/optimizer/parallel/parallel_pass.h"
|
#include "tools/optimizer/parallel/parallel_pass.h"
|
||||||
#include "include/registry/pass_registry.h"
|
#include "include/registry/pass_registry.h"
|
||||||
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
|
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
|
||||||
|
#include "tools/optimizer/fusion/transpose_fusion.h"
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
@ -82,6 +83,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
||||||
if (!config->trainModel) {
|
if (!config->trainModel) {
|
||||||
// remove quantdtype when awaretraining
|
// remove quantdtype when awaretraining
|
||||||
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
|
||||||
|
fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||||
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
|
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
|
||||||
|
@ -338,7 +340,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
||||||
|
|
||||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||||
format_pass->Init(config->fmk, config->trainModel);
|
format_pass->Init(config->fmk, config->trainModel);
|
||||||
if (!format_pass->RunOnlyForShape(old_graph)) {
|
if (!format_pass->Run(old_graph)) {
|
||||||
MS_LOG(ERROR) << "Run format pass failed.";
|
MS_LOG(ERROR) << "Run format pass failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "tools/converter/import/mindir_adjust.h"
|
#include "tools/converter/import/mindir_adjust.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -202,6 +203,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel);
|
||||||
|
if (!insert_transpose->Run(func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
|
if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType_CAFFE;
|
using mindspore::lite::converter::FmkType_CAFFE;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
@ -103,6 +104,11 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false);
|
||||||
|
if (!insert_transpose->Run(res_graph_)) {
|
||||||
|
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -0,0 +1,511 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
#include <queue>
|
||||||
|
#include <set>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "src/common/common.h"
|
||||||
|
#include "src/common/utils.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
|
using mindspore::lite::NCHW_SHAPE;
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kNCHWDimNumber = 4;
|
||||||
|
const std::vector<int> NH2NC = {0, 3, 1, 2};
|
||||||
|
const std::vector<int> NC2NH = {0, 2, 3, 1};
|
||||||
|
bool IsSpecialType(const CNodePtr &cnode) {
|
||||||
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) ||
|
||||||
|
opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) ||
|
||||||
|
opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
|
||||||
|
MS_ASSERT(cnode != nullptr);
|
||||||
|
auto prim_node = cnode->input(0);
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||||
|
MS_ASSERT(prim != nullptr);
|
||||||
|
auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
|
||||||
|
auto &specify_nchw_op_map = opt::GetNCHWOpMap();
|
||||||
|
if (fmk_type_ == lite::converter::FmkType_TFLITE) {
|
||||||
|
if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||||
|
trans_info->post_ = opt::kNCHW2NHWC;
|
||||||
|
} else if (fmk_type_ == lite::converter::FmkType_TF) {
|
||||||
|
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
|
||||||
|
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||||
|
trans_info->post_ = opt::kNHWC2NCHW;
|
||||||
|
}
|
||||||
|
if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
|
||||||
|
trans_info->pre_ = opt::kNHWC2NCHW;
|
||||||
|
trans_info->post_ = opt::kNCHW2NHWC;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
|
||||||
|
if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
|
||||||
|
GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
trans_info->pre_ = opt::kNCHW2NHWC;
|
||||||
|
trans_info->post_ = opt::kNHWC2NCHW;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a Transpose node with permutation `perm` and tags its primitive with the format of the data it consumes.
//
// When `before` is true the transpose is fed by input `index` of `cnode` and is named "<cnode>_pre<index - 1>";
// otherwise it is fed by `cnode` itself and is named "<cnode>_post".
// An NC2NH permutation marks the transpose with format NCHW, an NH2NC permutation with format NHWC; any other
// permutation leaves the format attribute unset.
//
// Returns the new node, or nullptr when the transpose could not be created.
AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                    const std::vector<int> &perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
  std::string trans_name =
    before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
  AnfNodePtr new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
  // Callers test the result for nullptr, so a failed creation must be reported here instead of being dereferenced.
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return nullptr;
  }
  auto new_input_cnode = new_input->cast<CNodePtr>();
  if (new_input_cnode == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return nullptr;
  }
  auto new_input_prim = GetValueNode<PrimitivePtr>(new_input_cnode->input(0));
  if (new_input_prim == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return nullptr;
  }
  if (perm == NC2NH) {
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
  } else if (perm == NH2NC) {
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
  }
  return new_input;
}
|
||||||
|
|
||||||
|
// Creates a transpose node for `cnode` and wires it into `func_graph`.
//
// When `before` is true the transpose replaces input `index` of `cnode`; otherwise every use of `cnode` is replaced
// by the transpose. Nothing is rewired when the generated node is `cnode` itself or already is input `index`.
//
// Returns lite::RET_OK on success and lite::RET_ERROR when the transpose cannot be generated or no graph manager is
// available.
STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
                                    bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return lite::RET_ERROR;
  }
  if (before) {
    auto tr = manager->Transact();
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    // Use the manager resolved above: it is the one guaranteed to be non-null at this point.
    manager->Replace(cnode, new_input);
  }
  return lite::RET_OK;
}
|
||||||
|
|
||||||
|
// Inserts a transpose with permutation `perm` in front of the inputs of `cnode` that the op maps designate.
//
// The primitive must appear in the NHWC or the NCHW op map; the NCHW map takes precedence when both contain it.
// An empty index list in the map means "every input", except for ResizeGrad with the NEAREST method, where only
// input 1 is transposed.
//
// Returns lite::RET_OK on success, lite::RET_ERROR when the op is in neither map or an insertion fails.
STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  const auto &nhwc_ops = opt::GetNHWCOpMap();
  const auto &nchw_ops = opt::GetNCHWOpMap();
  const auto nhwc_iter = nhwc_ops.find(prim->name());
  const auto nchw_iter = nchw_ops.find(prim->name());
  const bool in_nhwc_map = nhwc_iter != nhwc_ops.end();
  const bool in_nchw_map = nchw_iter != nchw_ops.end();
  if (!in_nhwc_map && !in_nchw_map) {
    MS_LOG(ERROR) << "op don't meet nhwc condition.";
    return lite::RET_ERROR;
  }

  std::vector<size_t> insert_index = in_nchw_map ? nchw_iter->second : nhwc_iter->second;
  if (insert_index.empty()) {
    const bool nearest_resize_grad =
      opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
      GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST);
    if (nearest_resize_grad) {
      insert_index.push_back(1);
    } else {
      for (size_t input_pos = 1; input_pos < cnode->size(); ++input_pos) {
        insert_index.push_back(input_pos);
      }
    }
  }

  for (size_t i = 0; i < insert_index.size(); ++i) {
    if (GenNewInput(func_graph, cnode, perm, true, insert_index[i]) == lite::RET_OK) {
      continue;
    }
    MS_LOG(ERROR) << "generate a new input failed.";
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
|
||||||
|
|
||||||
|
// Inserts a transpose with permutation `perm` behind `cnode`.
//
// A node whose abstract is not a tuple gets a single transpose on its output. For a tuple output, each user is
// handled separately: it must be a TupleGetItem, unless train_flag_ is set, in which case a temporary TupleGetItem
// (index 0) is generated, used as the insertion point and replaced by its own input(1) afterwards.
//
// Returns lite::RET_OK on success and lite::RET_ERROR when a user is invalid or a transpose cannot be inserted.
STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                            const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) == lite::RET_OK) {
      return lite::RET_OK;
    }
    MS_LOG(ERROR) << "generate a new input failed.";
    return lite::RET_ERROR;
  }

  // Work on a copy of the user list: the loop below rewrites graph edges through the manager.
  auto users_snapshot = func_graph->manager()->node_users()[cnode];
  for (auto &user : users_snapshot) {
    auto consumer = user.first;
    CNodePtr inserted_get_item = nullptr;
    const bool is_get_item = opt::CheckPrimitiveType(consumer, prim::kPrimTupleGetItem);
    if (!is_get_item && !train_flag_) {
      MS_LOG(ERROR) << "post node is invalid.";
      return lite::RET_ERROR;
    }
    if (!is_get_item) {
      inserted_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
      consumer = inserted_get_item;
      func_graph->manager()->Replace(cnode, inserted_get_item);
    }
    if (func_graph->manager()->node_users()[consumer].empty()) {
      continue;
    }
    auto consumer_cnode = consumer->cast<CNodePtr>();
    if (GenNewInput(func_graph, consumer_cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
    if (inserted_get_item != nullptr) {
      func_graph->manager()->Replace(inserted_get_item, inserted_get_item->input(1));
    }
  }
  return lite::RET_OK;
}
|
||||||
|
|
||||||
|
// Rewrites 4-D graph inputs consumed by `cnode`: the parameter's abstract shape is reordered to
// {N, H, W, C} (picked from the current shape via the NCHW_SHAPE indices) and a transpose with the NH2NC permutation,
// tagged with format NHWC, is put between the parameter and its users.
//
// Nothing is done for TF and TFLITE models (lite::RET_NO_CHANGE is returned). Inputs that are not parameters, that
// carry default data, or that are not 4-D are skipped, as is the single input of an ONNX graph whose shape has 3 in
// the last dimension and -1 in dimension 1.
//
// Returns lite::RET_OK on success and lite::RET_ERROR when a parameter's abstract is malformed or the transpose
// cannot be generated.
STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
    return lite::RET_NO_CHANGE;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto node = cnode->input(i);
    if (!utils::isa<ParameterPtr>(node)) {
      continue;
    }
    auto param_node = node->cast<ParameterPtr>();
    if (param_node->has_default()) {
      continue;
    }
    auto abstract_base = param_node->abstract();
    if (abstract_base == nullptr) {
      MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
      return lite::RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
      MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
    if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
      MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
    if (shape_vector.size() != 4) {
      continue;
    }
    if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
        shape_vector[1] == -1) {
      continue;
    }
    std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
                                     shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
    abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
    auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
    // Validate the generated node before touching it; it used to be dereferenced ahead of this check.
    if (trans_cnode == nullptr) {
      MS_LOG(ERROR) << "generate a transpose node failed.";
      return lite::RET_ERROR;
    }
    auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
    if (new_input_prim == nullptr) {
      MS_LOG(ERROR) << "generate a transpose node failed.";
      return lite::RET_ERROR;
    }
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
    func_graph->manager()->Replace(param_node, trans_cnode);
  }
  return lite::RET_OK;
}
|
||||||
|
|
||||||
|
// Surrounds `cnode` with the pre/post transposes selected by GetTransNodeFormatType.
//
// Returns lite::RET_NO_CHANGE when either side needs no transpose, lite::RET_ERROR when an insertion fails and
// lite::RET_OK otherwise. Adam and SGD nodes only receive the pre-transposes.
STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  opt::TransTypePair trans_info;
  GetTransNodeFormatType(cnode, &trans_info);
  const bool needs_transpose = trans_info.pre_ != opt::kNONE && trans_info.post_ != opt::kNONE;
  if (!needs_transpose) {
    return lite::RET_NO_CHANGE;
  }

  const std::vector<int> &before_perm = (trans_info.pre_ == opt::kNHWC2NCHW) ? NH2NC : NC2NH;
  const std::vector<int> &after_perm = (trans_info.post_ == opt::kNCHW2NHWC) ? NC2NH : NH2NC;

  const auto pre_status = InsertPreTransNode(func_graph, cnode, before_perm);
  if (pre_status != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }

  const bool is_optimizer_node =
    opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD);
  if (is_optimizer_node) {
    return lite::RET_OK;
  }

  const auto post_status = InsertPostTransNode(func_graph, cnode, after_perm);
  if (post_status != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
|
||||||
|
|
||||||
|
void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||||
|
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||||
|
auto sub_inputs = sub_graph->get_inputs();
|
||||||
|
sub_inputs_map_[sub_graph] = sub_inputs;
|
||||||
|
for (auto &node : sub_inputs) {
|
||||||
|
auto param_node = node->cast<ParameterPtr>();
|
||||||
|
MS_ASSERT(param_node != nullptr);
|
||||||
|
auto node_name = node->fullname_with_scope();
|
||||||
|
auto last_underline = node_name.find_last_of("_");
|
||||||
|
node_name = node_name.substr(0, last_underline);
|
||||||
|
last_underline = node_name.find_last_of("_");
|
||||||
|
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
|
||||||
|
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
|
||||||
|
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||||
|
ShapeVector shape_vec = {-1};
|
||||||
|
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||||
|
MS_ASSERT(trans_cnode != nullptr);
|
||||||
|
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||||
|
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
|
||||||
|
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
lite::DataInfo data_info;
|
||||||
|
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||||
|
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
|
||||||
|
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
|
||||||
|
if (data_info.data_.empty()) {
|
||||||
|
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
|
||||||
|
} else {
|
||||||
|
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
|
||||||
|
data_info.data_.data(), data_info.data_.size()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertTranspose::ResetSubGraphInput() {
|
||||||
|
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
|
||||||
|
auto &sub_graph = iter->first;
|
||||||
|
auto &sub_inputs = iter->second;
|
||||||
|
auto manager = sub_graph->manager();
|
||||||
|
MS_ASSERT(manager != nullptr);
|
||||||
|
for (auto &sub_input : sub_inputs) {
|
||||||
|
auto param_node = sub_graph->add_parameter();
|
||||||
|
MS_ASSERT(param_node != nullptr);
|
||||||
|
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||||
|
param_node->set_name(sub_input->fullname_with_scope());
|
||||||
|
manager->Replace(sub_input, param_node);
|
||||||
|
auto sub_param_input = sub_input->cast<ParameterPtr>();
|
||||||
|
MS_ASSERT(sub_param_input != nullptr);
|
||||||
|
sub_param_input->set_default_param(nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||||
|
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||||
|
auto return_node = sub_graph->get_return();
|
||||||
|
auto origin_input = return_node->inputs();
|
||||||
|
lite::RemoveIfDepend(return_node);
|
||||||
|
lite::RemoveIfMakeTuple(return_node);
|
||||||
|
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||||
|
if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto node_name = return_node->input(i)->fullname_with_scope();
|
||||||
|
if (node_name.substr(node_name.size() - 5) != "_post") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||||
|
MS_ASSERT(trans_cnode != nullptr);
|
||||||
|
auto trans_input = trans_cnode->input(1);
|
||||||
|
auto trans_input_name = trans_input->fullname_with_scope();
|
||||||
|
if (utils::isa<ParameterPtr>(trans_input)) {
|
||||||
|
trans_input->cast<ParameterPtr>()->set_name(node_name);
|
||||||
|
} else if (utils::isa<CNodePtr>(trans_input)) {
|
||||||
|
trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
|
||||||
|
}
|
||||||
|
trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
|
||||||
|
trans_cnode->set_fullname_with_scope(trans_input_name);
|
||||||
|
}
|
||||||
|
return_node->set_inputs(origin_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||||
|
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||||
|
auto return_node = sub_graph->get_return();
|
||||||
|
auto origin_inputs = return_node->inputs();
|
||||||
|
lite::RemoveIfDepend(return_node);
|
||||||
|
lite::RemoveIfMakeTuple(return_node);
|
||||||
|
AbstractBasePtrList abstract_list;
|
||||||
|
bool infer_done = true;
|
||||||
|
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||||
|
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
|
||||||
|
MS_ASSERT(abstract_base != nullptr);
|
||||||
|
abstract_list.emplace_back(abstract_base->Clone());
|
||||||
|
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
|
||||||
|
MS_ASSERT(abstract_tensor != nullptr);
|
||||||
|
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
|
||||||
|
MS_ASSERT(shape_ptr != nullptr);
|
||||||
|
auto shape = shape_ptr->shape();
|
||||||
|
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
|
||||||
|
infer_done = false;
|
||||||
|
}
|
||||||
|
if (utils::isa<CNodePtr>(return_node->input(i))) {
|
||||||
|
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||||
|
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
|
||||||
|
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
|
||||||
|
}
|
||||||
|
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||||
|
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
|
||||||
|
infer_done = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return_node->set_inputs(origin_inputs);
|
||||||
|
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
|
||||||
|
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||||
|
} else {
|
||||||
|
if (abstract_list.size() != 1) {
|
||||||
|
MS_LOG(ERROR) << "cnode output is invalid.";
|
||||||
|
}
|
||||||
|
cnode->set_abstract(abstract_list.front());
|
||||||
|
}
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
|
||||||
|
MS_ASSERT(func_graph != nullptr);
|
||||||
|
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||||
|
auto manager = Manage(func_graph, true);
|
||||||
|
if (manager == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "manager is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
int status;
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (IsSpecialType(cnode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (main_graph) {
|
||||||
|
status = HandleGraphInput(func_graph, cnode);
|
||||||
|
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||||
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
(void)BasicProcess(sub_func_graph, false);
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
(void)BasicProcess(sub_func_graph, false);
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
status = HandleGraphNode(func_graph, cnode);
|
||||||
|
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_ASSERT(func_graph != nullptr);
|
||||||
|
auto manager = Manage(func_graph, true);
|
||||||
|
if (manager == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "manager is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
if (prim->GetAttr(opt::kInferDone) != nullptr) {
|
||||||
|
prim->EraseAttr(opt::kInferDone);
|
||||||
|
}
|
||||||
|
if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||||
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(void)ResetFuncGraph(sub_func_graph);
|
||||||
|
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(void)ResetFuncGraph(sub_func_graph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool InsertTranspose::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_ASSERT(func_graph != nullptr);
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(node);
|
||||||
|
if (prim == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// insert transpose for some ops whose format must be NHWC, which is depend on framework.
|
||||||
|
// In this process, tranpose can be fused, which the original graph may not be able to restored.
|
||||||
|
if (!BasicProcess(func_graph, true)) {
|
||||||
|
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ResetSubGraphInput();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "tools/converter/converter_flags.h"
|
||||||
|
#include "tools/optimizer/common/format_utils.h"
|
||||||
|
#include "tools/anf_exporter/fetch_content.h"
|
||||||
|
|
||||||
|
using mindspore::lite::converter::FmkType;
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class InsertTranspose {
|
||||||
|
public:
|
||||||
|
InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {}
|
||||||
|
~InsertTranspose() = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph);
|
||||||
|
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||||
|
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||||
|
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
|
||||||
|
size_t index = 0);
|
||||||
|
|
||||||
|
private:
|
||||||
|
AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
|
const std::vector<int> &perm, bool before, size_t index);
|
||||||
|
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
||||||
|
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
|
||||||
|
void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info);
|
||||||
|
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
|
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
|
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||||
|
void ResetSubGraphInput();
|
||||||
|
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||||
|
void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||||
|
FmkType fmk_type_{lite::converter::FmkType_MS};
|
||||||
|
bool train_flag_{false};
|
||||||
|
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
|
|
@ -35,6 +35,7 @@
|
||||||
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
|
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType_ONNX;
|
using mindspore::lite::converter::FmkType_ONNX;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -90,6 +91,11 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false);
|
||||||
|
if (!insert_transpose->Run(res_graph_)) {
|
||||||
|
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
|
if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -33,6 +33,7 @@
|
||||||
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType_TF;
|
using mindspore::lite::converter::FmkType_TF;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -575,6 +576,11 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false);
|
||||||
|
if (!insert_transpose->Run(res_graph_)) {
|
||||||
|
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -30,6 +30,7 @@
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
|
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
|
#include "tools/converter/parser/insert_transpose.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType_TFLITE;
|
using mindspore::lite::converter::FmkType_TFLITE;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
@ -104,6 +105,11 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false);
|
||||||
|
if (!insert_transpose->Run(res_graph_)) {
|
||||||
|
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
constexpr auto kInferDone = "infer_done";
|
constexpr auto kInferDone = "infer_done";
|
||||||
constexpr auto kTransDone = "trans_done";
|
|
||||||
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
|
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
|
||||||
struct TransTypePair {
|
struct TransTypePair {
|
||||||
FormatTransNodeType pre_;
|
FormatTransNodeType pre_;
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/optimizer/fusion/transpose_fusion.h"
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/converter/quant_param_holder.h"
|
||||||
|
#include "mindspore/core/ops/transpose.h"
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore::opt {
|
||||||
|
namespace {
// Axis permutation taking an NHWC layout to NCHW.
const std::vector<int> NH2NC{0, 3, 1, 2};
// Axis permutation taking an NCHW layout to NHWC.
const std::vector<int> NC2NH{0, 2, 3, 1};
}  // namespace
|
||||||
|
bool IsBNCNode(const BaseRef &n) {
|
||||||
|
if (utils::isa<AnfNodePtr>(n)) {
|
||||||
|
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||||
|
return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
|
||||||
|
CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the pattern {bn, {transpose, conv, param}, param, param, <any sequence>}: a batch-norm
// node (see IsBNCNode) over a transpose-of-conv sub-pattern, two parameter nodes and a trailing
// sequence variable.
VectorRef TransposeFusion::DefineBNPattern() const {
  auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto is_conv = std::make_shared<CondVar>(IsConvNode);
  auto is_perm = std::make_shared<CondVar>(IsParamNode);
  VectorRef trans_over_conv = VectorRef({is_transpose, is_conv, is_perm});

  auto is_bn = std::make_shared<CondVar>(IsBNCNode);
  auto is_mean = std::make_shared<CondVar>(IsParamNode);
  auto is_variable = std::make_shared<CondVar>(IsParamNode);
  auto remaining = std::make_shared<SeqVar>();
  return VectorRef({is_bn, trans_over_conv, is_mean, is_variable, remaining});
}
|
||||||
|
|
||||||
|
// Builds the pattern {scale, {transpose, conv, param}, param, <any sequence>}: a ScaleFusion
// node over a transpose-of-conv sub-pattern, one parameter node and a trailing sequence variable.
VectorRef TransposeFusion::DefineActivationscalePattern() const {
  auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto is_conv = std::make_shared<CondVar>(IsConvNode);
  auto is_perm = std::make_shared<CondVar>(IsParamNode);
  VectorRef trans_over_conv = VectorRef({is_transpose, is_conv, is_perm});

  auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
  auto is_scale_param = std::make_shared<CondVar>(IsParamNode);
  auto remaining = std::make_shared<SeqVar>();
  return VectorRef({is_scale, trans_over_conv, is_scale_param, remaining});
}
|
||||||
|
|
||||||
|
// Builds the pattern {activation, {transpose, conv, param}}: an Activation node over a
// transpose-of-conv sub-pattern.
VectorRef TransposeFusion::DefineActivationPattern() const {
  auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto is_conv = std::make_shared<CondVar>(IsConvNode);
  auto is_perm = std::make_shared<CondVar>(IsParamNode);
  VectorRef trans_over_conv = VectorRef({is_transpose, is_conv, is_perm});

  auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
  return VectorRef({is_activation, trans_over_conv});
}
|
||||||
|
|
||||||
|
// Builds the pattern {bias_add, {transpose, conv, param}, param}: a BiasAdd node over a
// transpose-of-conv sub-pattern and one parameter node.
VectorRef TransposeFusion::DefineBiasAddPattern() const {
  auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto is_conv = std::make_shared<CondVar>(IsConvNode);
  auto is_perm = std::make_shared<CondVar>(IsParamNode);
  VectorRef trans_over_conv = VectorRef({is_transpose, is_conv, is_perm});

  auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
  auto is_bias = std::make_shared<CondVar>(IsParamNode);
  return VectorRef({is_bias_add, trans_over_conv, is_bias});
}
|
||||||
|
|
||||||
|
// Builds the pattern {transpose, transpose, param}: two Transpose conditions followed by a
// parameter node.
VectorRef TransposeFusion::DefineTransTransPattern() const {
  auto first_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto second_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto is_perm = std::make_shared<CondVar>(IsParamNode);
  return VectorRef({second_transpose, first_transpose, is_perm});
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
|
||||||
|
std::unordered_map<std::string, VectorRef> patterns;
|
||||||
|
patterns["BNPatternName"] = DefineBNPattern();
|
||||||
|
patterns["ActivationPatternName"] = DefineActivationPattern();
|
||||||
|
patterns["BiasAddPatternName"] = DefineBiasAddPattern();
|
||||||
|
patterns["ScalePatternName"] = DefineActivationscalePattern();
|
||||||
|
patterns["TransTransPatternName"] = DefineTransTransPattern();
|
||||||
|
return patterns;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a Transpose CNode in `func_graph` with inputs {input_node, perm}, names it
// `cnode_name`, and attaches a QuantParamHolder(2, 1) to its primitive under "quant_params".
// Returns the new node.
CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
                          const std::string &cnode_name) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto transpose_prim = std::make_shared<ops::Transpose>();
  MS_ASSERT(transpose_prim != nullptr);
  auto transpose_cnode = func_graph->NewCNode(transpose_prim, {input_node, perm});
  MS_ASSERT(transpose_cnode != nullptr);
  transpose_cnode->set_fullname_with_scope(cnode_name);

  auto holder = std::make_shared<lite::QuantParamHolder>(2, 1);
  auto node_prim = GetValueNode<PrimitivePtr>(transpose_cnode->input(0));
  node_prim->AddAttr("quant_params", holder);
  return transpose_cnode;
}
|
||||||
|
|
||||||
|
// Handles a Transpose whose data input is another Transpose. When the two permutations are
// NH2NC followed by NC2NH, or NC2NH followed by NH2NC, the pair cancels out and the input of
// the inner transpose is returned as the replacement. Returns nullptr in every other case
// (null arguments, unexpected node types, unreadable permutations, non-cancelling pair).
AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph,
                                             const mindspore::AnfNodePtr &node) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  const bool graph_missing = CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK;
  if (graph_missing || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  auto outer_trans = node->cast<CNodePtr>();
  if (!CheckPrimitiveType(outer_trans, prim::kPrimTranspose)) {
    return nullptr;
  }
  if (!CheckPrimitiveType(outer_trans->input(1), prim::kPrimTranspose)) {
    return nullptr;
  }

  std::vector<int> outer_perm;
  if (GetTransposePerm(outer_trans, &outer_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get tanspose perm failed.";
    return nullptr;
  }

  std::vector<int> inner_perm;
  auto inner_trans = outer_trans->input(1)->cast<CNodePtr>();
  if (inner_trans == nullptr) {
    return nullptr;
  }
  if (GetTransposePerm(inner_trans, &inner_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get tanspose perm failed.";
    return nullptr;
  }

  // The pair is an identity only when one permutation undoes the other.
  if (inner_perm == NH2NC && outer_perm == NC2NH) {
    return inner_trans->input(1);
  }
  if (inner_perm == NC2NH && outer_perm == NH2NC) {
    return inner_trans->input(1);
  }
  return nullptr;
}
|
||||||
|
|
||||||
|
// Rewrites a matched pattern.
//
// "TransTransPatternName" is delegated to TransTransFusion(). For every other pattern the
// transpose feeding input 1 of the matched node is moved behind it: a new "<name>_post"
// transpose (sharing the original perm input) is created on top of the matched node, the
// matched node is rewired to read the transpose's data input directly, and the new transpose
// is returned as the replacement. Returns nullptr when the arguments are null, the graph has
// no manager, or the nodes do not have the expected inputs.
AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                    const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (pattern_name == "TransTransPatternName") {
    return TransTransFusion(func_graph, node);
  }
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  if (CheckIfCNodeIsNull(node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  auto any_cnode = node->cast<CNodePtr>();
  if (any_cnode->size() < 2) {
    MS_LOG(ERROR) << "matched node has no data input.";
    return nullptr;
  }
  const auto transpose_node = any_cnode->input(1);
  if (CheckIfCNodeIsNull(transpose_node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  const CNodePtr transpose_cnode = transpose_node->cast<CNodePtr>();
  // A transpose is expected to carry {primitive, data, perm}.
  if (transpose_cnode->size() < 3) {
    MS_LOG(ERROR) << "transpose node has no perm input.";
    return nullptr;
  }
  // Check the manager before creating any node so that a failure leaves the graph untouched.
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return nullptr;
  }
  auto perm_node = transpose_cnode->input(2);
  auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
  auto tr = manager->Transact();
  tr.SetEdge(any_cnode, 1, transpose_cnode->input(1));
  tr.Commit();
  return trans_post_node;
}
|
||||||
|
} // namespace mindspore::opt
|
|
@ -0,0 +1,48 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "tools/optimizer/graph/unify_format_pass.h"
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "schema/inner/model_generated.h"
|
||||||
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class TransposeFusion : public MultiplePatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit TransposeFusion(const std::string &name = "transpose_fusion", bool multigraph = true)
|
||||||
|
: MultiplePatternProcessPass(name, multigraph) {}
|
||||||
|
|
||||||
|
~TransposeFusion() override = default;
|
||||||
|
|
||||||
|
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
|
||||||
|
VectorRef DefineBNPattern() const;
|
||||||
|
VectorRef DefineActivationPattern() const;
|
||||||
|
VectorRef DefineActivationscalePattern() const;
|
||||||
|
VectorRef DefineTransTransPattern() const;
|
||||||
|
VectorRef DefineBiasAddPattern() const;
|
||||||
|
AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const;
|
||||||
|
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
|
||||||
|
const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
|
|
@ -165,6 +165,8 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
|
if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
|
||||||
auto set_status = SetCNodeAbstract(cnode, outputs, ret);
|
auto set_status = SetCNodeAbstract(cnode, outputs, ret);
|
||||||
|
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
|
||||||
if (set_status != lite::RET_OK) {
|
if (set_status != lite::RET_OK) {
|
||||||
MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
|
MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
|
||||||
FreeTensors(&inputs);
|
FreeTensors(&inputs);
|
||||||
|
@ -336,6 +338,7 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite:
|
||||||
tensor = GetCNodeTensorListVarInput(data_info);
|
tensor = GetCNodeTensorListVarInput(data_info);
|
||||||
} else {
|
} else {
|
||||||
tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
|
tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
|
||||||
|
tensor->set_format((Format)(data_info.format_));
|
||||||
}
|
}
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "new a lite tensor failed";
|
MS_LOG(ERROR) << "new a lite tensor failed";
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <map>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "tools/anf_exporter/fetch_content.h"
|
#include "tools/anf_exporter/fetch_content.h"
|
||||||
|
@ -41,9 +42,9 @@ class NodeInferShape {
|
||||||
bool JudgeOpSupportInfer(const CNodePtr &cnode);
|
bool JudgeOpSupportInfer(const CNodePtr &cnode);
|
||||||
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
|
std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
|
||||||
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);
|
std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);
|
||||||
|
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
|
|
||||||
STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
|
STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
|
||||||
STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
|
STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
|
||||||
lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);
|
lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);
|
||||||
|
|
|
@ -32,7 +32,13 @@ constexpr size_t kInputTripleNum = 3;
|
||||||
int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||||
auto first_input = cnode->input(1);
|
auto first_input = cnode->input(1);
|
||||||
|
if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
|
||||||
|
first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
|
||||||
|
}
|
||||||
auto second_input = cnode->input(2);
|
auto second_input = cnode->input(2);
|
||||||
|
if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
|
||||||
|
second_input = cnode->input(2)->cast<CNodePtr>()->input(1);
|
||||||
|
}
|
||||||
AnfNodePtr must_monad = nullptr;
|
AnfNodePtr must_monad = nullptr;
|
||||||
AnfNodePtr not_must_monad = nullptr;
|
AnfNodePtr not_must_monad = nullptr;
|
||||||
if (utils::isa<ValueNode>(first_input)) {
|
if (utils::isa<ValueNode>(first_input)) {
|
||||||
|
@ -72,6 +78,12 @@ int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr
|
||||||
pre_node = cnode->input(2);
|
pre_node = cnode->input(2);
|
||||||
post_node = cnode->input(1);
|
post_node = cnode->input(1);
|
||||||
}
|
}
|
||||||
|
if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
|
||||||
|
pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
|
||||||
|
}
|
||||||
|
if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
|
||||||
|
post_node = cnode->input(2)->cast<CNodePtr>()->input(1);
|
||||||
|
}
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
MS_ASSERT(manager != nullptr);
|
MS_ASSERT(manager != nullptr);
|
||||||
auto node_users = manager->node_users()[pre_node];
|
auto node_users = manager->node_users()[pre_node];
|
||||||
|
@ -102,6 +114,10 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
|
||||||
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
|
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
MS_ASSERT(manager != nullptr);
|
MS_ASSERT(manager != nullptr);
|
||||||
|
if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
|
||||||
|
manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
manager->Replace(cnode->input(0), make_tuple_prim);
|
manager->Replace(cnode->input(0), make_tuple_prim);
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -335,6 +335,9 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
|
if (manager == nullptr) {
|
||||||
|
manager = Manage(func_graph, true);
|
||||||
|
}
|
||||||
MS_ASSERT(manager != nullptr);
|
MS_ASSERT(manager != nullptr);
|
||||||
auto tr = manager->Transact();
|
auto tr = manager->Transact();
|
||||||
if (before) {
|
if (before) {
|
||||||
|
@ -428,6 +431,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
status = node_infer_shape_.InferShape(cnode);
|
status = node_infer_shape_.InferShape(cnode);
|
||||||
|
|
||||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||||
MS_LOG(ERROR) << "infer shape failed.";
|
MS_LOG(ERROR) << "infer shape failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
@ -523,47 +527,8 @@ STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const C
|
||||||
MS_LOG(ERROR) << "infer shape failed.";
|
MS_LOG(ERROR) << "infer shape failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
func_graph->manager()->Replace(param_node, trans_cnode);
|
|
||||||
}
|
|
||||||
return lite::RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
func_graph->manager()->Replace(param_node, trans_cnode);
|
||||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
|
||||||
auto prim_node = cnode->input(0);
|
|
||||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
|
||||||
MS_ASSERT(prim != nullptr);
|
|
||||||
if (prim->GetAttr(kTransDone) != nullptr && GetValue<bool>(prim->GetAttr(kTransDone))) {
|
|
||||||
return lite::RET_OK;
|
|
||||||
}
|
|
||||||
prim->AddAttr(kTransDone, MakeValue<bool>(true));
|
|
||||||
TransTypePair trans_info;
|
|
||||||
GetTransNodeFormatType(cnode, &trans_info);
|
|
||||||
if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) {
|
|
||||||
if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
|
|
||||||
return lite::RET_OK;
|
|
||||||
}
|
|
||||||
auto status = node_infer_shape_.InferShape(cnode);
|
|
||||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
|
||||||
MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
|
|
||||||
return lite::RET_ERROR;
|
|
||||||
}
|
|
||||||
return lite::RET_NO_CHANGE;
|
|
||||||
}
|
|
||||||
auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
|
|
||||||
auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
|
|
||||||
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
|
|
||||||
return lite::RET_ERROR;
|
|
||||||
}
|
|
||||||
auto status = node_infer_shape_.InferShape(cnode);
|
|
||||||
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
|
||||||
MS_LOG(ERROR) << "infer shape failed.";
|
|
||||||
return lite::RET_ERROR;
|
|
||||||
}
|
|
||||||
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
|
|
||||||
return lite::RET_ERROR;
|
|
||||||
}
|
}
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -763,58 +728,6 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
|
||||||
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
|
prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
|
|
||||||
MS_ASSERT(func_graph != nullptr);
|
|
||||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
|
||||||
auto manager = Manage(func_graph, true);
|
|
||||||
if (manager == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "manager is nullptr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto node_list = TopoSort(func_graph->get_return());
|
|
||||||
int status;
|
|
||||||
for (auto &node : node_list) {
|
|
||||||
if (!utils::isa<CNodePtr>(node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
if (IsSpecialType(cnode)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (main_graph && !need_reset_) {
|
|
||||||
status = HandleGraphInput(func_graph, cnode);
|
|
||||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
|
||||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
|
||||||
if (sub_func_graph == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
SetSubGraphInput(cnode, sub_func_graph);
|
|
||||||
(void)BasicProcess(sub_func_graph, false);
|
|
||||||
SetSubGraphOutput(cnode, sub_func_graph);
|
|
||||||
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
|
||||||
if (sub_func_graph == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
SetSubGraphInput(cnode, sub_func_graph);
|
|
||||||
(void)BasicProcess(sub_func_graph, false);
|
|
||||||
SetSubGraphOutput(cnode, sub_func_graph);
|
|
||||||
SetSubGraphAbstract(cnode, sub_func_graph);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
status = HandleGraphNode(func_graph, cnode);
|
|
||||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
|
||||||
MS_ASSERT(func_graph != nullptr);
|
MS_ASSERT(func_graph != nullptr);
|
||||||
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
|
||||||
|
@ -935,33 +848,6 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
|
||||||
if (prim->GetAttr(kInferDone) != nullptr) {
|
if (prim->GetAttr(kInferDone) != nullptr) {
|
||||||
prim->EraseAttr(kInferDone);
|
prim->EraseAttr(kInferDone);
|
||||||
}
|
}
|
||||||
if (prim->GetAttr(kTransDone) != nullptr) {
|
|
||||||
prim->EraseAttr(kTransDone);
|
|
||||||
}
|
|
||||||
if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) {
|
|
||||||
manager->Replace(node, cnode->input(1));
|
|
||||||
}
|
|
||||||
if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) {
|
|
||||||
auto cnode_abstract = cnode->abstract();
|
|
||||||
if (!utils::isa<abstract::AbstractTensorPtr>(cnode_abstract)) {
|
|
||||||
MS_LOG(ERROR) << "abstract is not abstract tensor.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto cnode_abstract_tensor = cnode_abstract->cast<abstract::AbstractTensorPtr>();
|
|
||||||
if (!utils::isa<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
|
|
||||||
MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto shape_ptr = utils::cast<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
|
|
||||||
auto input_abstract = GetCNodeInputAbstract(cnode, 1);
|
|
||||||
if (!utils::isa<abstract::AbstractTensorPtr>(input_abstract)) {
|
|
||||||
MS_LOG(ERROR) << "abstract is not abstract tensor.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto input_abstract_tensor = input_abstract->cast<abstract::AbstractTensorPtr>();
|
|
||||||
input_abstract_tensor->set_shape(shape_ptr);
|
|
||||||
manager->Replace(node, cnode->input(1));
|
|
||||||
}
|
|
||||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||||
if (sub_func_graph == nullptr) {
|
if (sub_func_graph == nullptr) {
|
||||||
|
@ -1025,18 +911,140 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
|
||||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
need_reset_ = true;
|
if (!RunNodeInferShape(func_graph)) {
|
||||||
// insert transpose for some ops whose format must be NHWC, which depends on the framework.
|
MS_LOG(ERROR) << "RunNodeInferShape failed.";
|
||||||
// In this process, transpose op cannot be fused to restore the original graph.
|
|
||||||
if (!BasicProcess(func_graph, true)) {
|
|
||||||
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ResetSubGraphInput();
|
ResetSubGraphInput();
|
||||||
// delete insert transpose op and update op output shape.
|
ResetFuncGraph(func_graph);
|
||||||
if (!ResetFuncGraph(func_graph)) {
|
return true;
|
||||||
MS_LOG(ERROR) << "reset func_graph failed.";
|
}
|
||||||
return false;
|
|
||||||
|
bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (IsSpecialType(cnode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||||
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
if (!RunNodeInferShape(sub_func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
|
||||||
|
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
if (!RunNodeInferShape(sub_func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto status = node_infer_shape_.InferShape(cnode);
|
||||||
|
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
|
||||||
|
MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
|
auto prim_node = cnode->input(0);
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||||
|
auto &nchw_op = GetNCHWOpMap();
|
||||||
|
if (!utils::isa<CNodePtr>(cnode->input(1))) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (utils::isa<CNodePtr>(cnode->input(1))) {
|
||||||
|
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||||
|
if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
|
||||||
|
InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
|
||||||
|
InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
|
||||||
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||||
|
auto manager = func_graph->manager();
|
||||||
|
if (manager == nullptr) {
|
||||||
|
manager = Manage(func_graph, true);
|
||||||
|
}
|
||||||
|
auto shape = node_infer_shape_.GetInputShape(cnode, 1);
|
||||||
|
std::vector<int> perm;
|
||||||
|
auto status = GetTransposePerm(cnode, &perm);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!shape.empty() && shape.size() != perm.size()) {
|
||||||
|
manager->Replace(cnode, cnode->input(1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (IsSpecialType(cnode)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||||
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
if (!DoFixFormat(sub_func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
|
||||||
|
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
|
||||||
|
if (sub_func_graph == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphInput(cnode, sub_func_graph);
|
||||||
|
if (!DoFixFormat(sub_func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "subgraph infer shape failed.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetSubGraphOutput(cnode, sub_func_graph);
|
||||||
|
SetSubGraphAbstract(cnode, sub_func_graph);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!RunDoFixFormat(func_graph, cnode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1049,22 +1057,23 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
if (prim == nullptr) {
|
if (prim == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (prim->GetAttr(kTransDone) != nullptr) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (!JudgeAllOpsCanInfer(func_graph)) {
|
if (!JudgeAllOpsCanInfer(func_graph)) {
|
||||||
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
MS_LOG(ERROR) << "exist op cannot support infer shape.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// insert transpose for some ops whose format must be NHWC, which depends on the framework.
|
if (!RunNodeInferShape(func_graph)) {
|
||||||
// In this process, transpose ops can be fused, so the original graph may not be restorable.
|
MS_LOG(ERROR) << "infer shape failed.";
|
||||||
if (!BasicProcess(func_graph, true)) {
|
|
||||||
MS_LOG(ERROR) << "run framework transpose unify failed.";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ResetSubGraphInput();
|
ResetSubGraphInput();
|
||||||
// if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
|
|
||||||
|
if (!DoFixFormat(func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "DoFixFormat failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ResetSubGraphInput();
|
||||||
|
|
||||||
if (!DecreaseTransposeForSingleOp(func_graph)) {
|
if (!DecreaseTransposeForSingleOp(func_graph)) {
|
||||||
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
|
MS_LOG(ERROR) << "run local trans insert optimizer failed.";
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -45,23 +45,25 @@ class UnifyFormatPass : public Pass {
|
||||||
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
|
bool RunOnlyForShape(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||||
|
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||||
|
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
|
||||||
|
size_t index = 0);
|
||||||
|
bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
|
bool DoFixFormat(const FuncGraphPtr &func_graph);
|
||||||
|
bool RunNodeInferShape(const FuncGraphPtr &func_graph);
|
||||||
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
|
||||||
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
bool ResetFuncGraph(const FuncGraphPtr &func_graph);
|
||||||
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
|
|
||||||
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
|
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
|
||||||
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
|
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
|
||||||
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
|
|
||||||
size_t index = 0);
|
|
||||||
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
|
void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
|
||||||
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||||
STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
|
||||||
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
std::set<CNodePtr> *visit_transposes);
|
std::set<CNodePtr> *visit_transposes);
|
||||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
|
||||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
||||||
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
|
||||||
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||||
void ResetSubGraphInput();
|
void ResetSubGraphInput();
|
||||||
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||||
|
|
Loading…
Reference in New Issue