!20196 [LITE] format
Merge pull request !20196 from yefeng/130-format-nnie
Commit c0179f61d8
@@ -27,20 +27,6 @@ bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const si
}

int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, int *perm, size_t perm_size, int *out_shape) {
  if (perms_num == 4) {
    const int nchw2nhwc[4] = {0, 2, 3, 1};
    const int nhwc2nchw[4] = {0, 3, 1, 2};
    const int trans3d[3] = {0, 2, 1};
    if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
      output->format_ = Format_NHWC;
    } else if (input->format_ == Format_NHWC && CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
      output->format_ = Format_NCHW;
    }
    // Though the perm is 4-D by default, the input can be a 3-D tensor; the op implementation should adapt to this.
    if (input->shape_size_ == 3) {
      ShapeSet(perm, &perm_size, trans3d, 3);
    }
  }
  // set output shape
  size_t in_shape_size = input->shape_size_;
  output->shape_size_ = in_shape_size;
@@ -76,13 +62,6 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
  TensorC *output = outputs[0];

  SetDataTypeFormat(output, input);
  if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
    output->data_type_ = kNumberTypeFloat32;
  }
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }

  const TensorC *perm_tensor = inputs[1];
  const int32_t *perm_data = (int32_t *)perm_tensor->data_;
  const size_t perms_num = (size_t)perm_tensor->shape_[0];
@@ -97,6 +76,28 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
  for (size_t i = 0; i < perms_num; i++) {
    ShapePush(perm, &perm_size, perm_data[i]);
  }
  if (perms_num == 4) {
    const int nchw2nhwc[4] = {0, 2, 3, 1};
    const int nhwc2nchw[4] = {0, 3, 1, 2};
    const int trans3d[3] = {0, 2, 1};
    if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, perms_num)) {
      output->format_ = Format_NHWC;
    } else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) &&
               CheckPermTransFormat(perm, nhwc2nchw, perms_num)) {
      output->format_ = Format_NCHW;
    }
    // Though the perm is 4-D by default, the input can be a 3-D tensor; the op implementation should adapt to this.
    if (input->shape_size_ == 3) {
      ShapeSet(perm, &perm_size, trans3d, 3);
    }
  }
  if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
    output->data_type_ = kNumberTypeFloat32;
  }
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }

  // set output shape
  int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
  SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
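The moved block deduces the output layout purely from the permutation vector: {0,2,3,1} maps NCHW to NHWC, {0,3,1,2} maps back, and any other perm leaves the format untouched. A self-contained sketch of that check (illustrative helper name, not the nnacl API):

```cpp
#include <cstddef>

// Minimal stand-in for CheckPermTransFormat: true when perm matches expected.
static bool PermMatches(const int *perm, const int *expected, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    if (perm[i] != expected[i]) return false;
  }
  return true;
}

int main() {
  const int perm[4] = {0, 2, 3, 1};       // the transpose's permutation input
  const int nchw2nhwc[4] = {0, 2, 3, 1};  // same constant as in the hunk above
  // An NCHW tensor transposed with {0,2,3,1} comes out NHWC.
  return PermMatches(perm, nchw2nhwc, 4) ? 0 : 1;
}
```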
@@ -247,6 +247,7 @@ if(MSLITE_ENABLE_CONVERTER)
    ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
    ${LITE_DIR}/tools/optimizer/fusion/transpose_fusion.cc
    ${LITE_DIR}/tools/optimizer/graph/add_tensor_array.cc
    ${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
    ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -282,6 +283,7 @@ if(MSLITE_ENABLE_CONVERTER)
    ${LITE_DIR}/tools/common/node_util.cc
    ${LITE_DIR}/tools/common/storage.cc
    ${LITE_DIR}/tools/converter/parser/inputs_adjust.cc
    ${LITE_DIR}/tools/converter/parser/insert_transpose.cc
    ${LITE_DIR}/tools/converter/parser/unused_node_remove_pass.cc
    ${LITE_DIR}/tools/converter/parser/conv1d_inout_adjust.cc
    ${LITE_DIR}/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc
@@ -22,6 +22,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"
#include "tools/optimizer/common/format_utils.h"

namespace mindspore {
namespace lite {
@@ -284,12 +285,15 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
    return RET_ERROR;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim->GetAttr(ops::kFormat) != nullptr) {
  if (prim->GetAttr(ops::kFormat) != nullptr && !opt::CheckPrimitiveType(cnode, prim::kPrimResize)) {
    auto value = prim->GetAttr(ops::kFormat);
    if (value->isa<mindspore::Int64Imm>()) {
      data_info->format_ = GetValue<int64_t>(value);
    }
  }
  if (!param_node->has_default()) {
    data_info->format_ = NHWC;
  }
  // attr weightFormat is only used by conv-like ops' second input
  if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
       opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
@@ -347,6 +351,31 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
    return ret;
  }

int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                      DataInfo *data_info) {
  data_info->format_ = mindspore::NHWC;
  auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
  if (input_node_prim->GetAttr(ops::kFormat) != nullptr) {
    auto value = input_node_prim->GetAttr(ops::kFormat);
    if (value->isa<mindspore::Int64Imm>()) {
      data_info->format_ = GetValue<int64_t>(value);
    }
  }
  if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
    std::vector<int> perm;
    if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
      return RET_ERROR;
    }
    if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
        (data_info->format_ == NHWC || data_info->format_ == KHWC)) {
      data_info->format_ = NCHW;
    } else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
      data_info->format_ = NHWC;
    }
  }
  return RET_OK;
}

int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                       DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
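SetFormatForCnode only flips the recorded format when the producing node is a Transpose whose perm is exactly one of the two canonical layout swaps. A standalone sketch of that decision, with hypothetical enum and function names (not the converter's types):

```cpp
#include <vector>

enum class Layout { NHWC, NCHW, KHWC };

// Illustrative only: mirrors the perm checks in SetFormatForCnode above.
Layout LayoutAfterTranspose(Layout in, const std::vector<int> &perm) {
  const std::vector<int> nh2nc = {0, 3, 1, 2};
  const std::vector<int> nc2nh = {0, 2, 3, 1};
  if (perm == nh2nc && (in == Layout::NHWC || in == Layout::KHWC)) return Layout::NCHW;
  if (perm == nc2nh && in == Layout::NCHW) return Layout::NHWC;
  return in;  // any other permutation leaves the recorded format unchanged
}
```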
@@ -368,7 +397,11 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
  }
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->format_ = mindspore::NHWC;
  auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "set format for cnode failed";
    return RET_ERROR;
  }
  data_info->data_type_ = type_ptr->type_id();
  data_info->shape_ = dims;
  data_info->node_type_ = NodeType_CNode;
@@ -36,6 +36,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    ${CMAKE_CURRENT_SOURCE_DIR}/parser/unused_node_remove_pass.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/parser/insert_transpose.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/import/mindspore_importer.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/import/primitive_adjust.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
@@ -73,6 +74,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    ../optimizer/fusion/tf_gelu_fusion.cc
    ../optimizer/fusion/onnx_gelu_fusion.cc
    ../optimizer/fusion/squeeze_fusion.cc
    ../optimizer/fusion/transpose_fusion.cc
    ../optimizer/fisson/eliminate_concat_split.cc
    ../optimizer/fisson/fisson_util.cc
    ../optimizer/fisson/iter_node_outputs.cc
@@ -67,6 +67,7 @@
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"

using std::string;
namespace mindspore::lite {
@@ -82,6 +83,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
  if (!config->trainModel) {
    // remove quant dtype when aware-training
    fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
    fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
    fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
    fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
    auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
@@ -338,7 +340,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con

  auto format_pass = std::make_shared<opt::UnifyFormatPass>();
  format_pass->Init(config->fmk, config->trainModel);
  if (!format_pass->RunOnlyForShape(old_graph)) {
  if (!format_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "Run format pass failed.";
    return nullptr;
  }
@@ -23,6 +23,7 @@
#include "tools/converter/import/mindir_adjust.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"

namespace mindspore::lite {
namespace {
@@ -202,6 +203,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_MS, flag.trainModel);
  if (!insert_transpose->Run(func_graph)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
    MS_LOG(ERROR) << "WeightFormatTransform failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
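The same three-step integration recurs in each front end below (Caffe, ONNX, TF, TFLite): construct the pass with the parser's FmkType, run it on the freshly imported graph, and abort the conversion on failure. Schematically (fmk_type, train_flag, and graph stand for the per-parser values):

```cpp
auto insert_transpose = std::make_shared<InsertTranspose>(fmk_type, train_flag);
if (!insert_transpose->Run(graph)) {
  MS_LOG(ERROR) << "Run insert transpose failed.";
  return nullptr;  // conversion is aborted before weight-format transformation
}
```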
@@ -31,6 +31,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/insert_transpose.h"

using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore::lite {
@@ -103,6 +104,11 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_CAFFE, false);
  if (!insert_transpose->Run(res_graph_)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
    MS_LOG(ERROR) << "WeightFormatTransform failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -0,0 +1,511 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/parser/insert_transpose.h"
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include "ops/op_utils.h"
#include "src/common/common.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"

using mindspore::lite::NCHW_SHAPE;
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNCHWDimNumber = 4;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
bool IsSpecialType(const CNodePtr &cnode) {
  if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || opt::CheckPrimitiveType(cnode, prim::kPrimDepend) ||
      opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(cnode, opt::kPrimMakeTupleV2) ||
      opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
    return true;
  }
  return false;
}
}  // namespace

void InsertTranspose::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) {
  MS_ASSERT(cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
  auto &specify_nchw_op_map = opt::GetNCHWOpMap();
  if (fmk_type_ == lite::converter::FmkType_TFLITE) {
    if (specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
      return;
    }
    trans_info->pre_ = opt::kNHWC2NCHW;
    trans_info->post_ = opt::kNCHW2NHWC;
  } else if (fmk_type_ == lite::converter::FmkType_TF) {
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end() && opt::GetFormat(cnode) == NCHW) {
      trans_info->pre_ = opt::kNCHW2NHWC;
      trans_info->post_ = opt::kNHWC2NCHW;
    }
    if (specify_nchw_op_map.find(prim->name()) != specify_nchw_op_map.end()) {
      trans_info->pre_ = opt::kNHWC2NCHW;
      trans_info->post_ = opt::kNCHW2NHWC;
    }
  } else {
    if (specify_nhwc_op_map.find(prim->name()) != specify_nhwc_op_map.end()) {
      if (fmk_type_ == lite::converter::FmkType_ONNX && prim->GetAttr(ops::kFormat) != nullptr &&
          GetValue<int64_t>(prim->GetAttr(ops::kFormat)) == NHWC) {
        return;
      }
      trans_info->pre_ = opt::kNCHW2NHWC;
      trans_info->post_ = opt::kNHWC2NCHW;
    }
  }
}

AnfNodePtr InsertTranspose::GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                    const std::vector<int> &perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;
  AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
  std::string trans_name =
    before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
  new_input = opt::GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
  auto new_input_prim = GetValueNode<PrimitivePtr>(new_input->cast<CNodePtr>()->input(0));
  if (perm == NC2NH) {
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
  } else if (perm == NH2NC) {
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
  }
  return new_input;
}

STATUS InsertTranspose::GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm,
                                    bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr new_input = nullptr;

  new_input = GenNewInputWithoutShape(func_graph, cnode, perm, before, index);
  if (new_input == nullptr) {
    MS_LOG(ERROR) << "generate a transpose node failed.";
    return lite::RET_ERROR;
  }
  if (new_input == cnode->input(index) || new_input == cnode) {
    return lite::RET_OK;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  if (before) {
    tr.SetEdge(cnode, index, new_input);
    tr.Commit();
  } else {
    func_graph->manager()->Replace(cnode, new_input);
  }
  return lite::RET_OK;
}

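// GenNewInput (above) splices the new transpose in two different ways:
// for an input ("before"), only the cnode's edge at `index` is redirected;
// for an output, every user of the cnode is rewired to the new node via Replace.
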
STATUS InsertTranspose::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                           const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  auto &specify_nhwc_op_map = opt::GetNHWCOpMap();
  auto &specify_nchw_op_map = opt::GetNCHWOpMap();
  if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
      specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
    MS_LOG(ERROR) << "op don't meet nhwc condition.";
    return lite::RET_ERROR;
  }
  std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
                                       ? specify_nhwc_op_map.at(prim->name())
                                       : specify_nchw_op_map.at(prim->name());
  if (insert_index.empty()) {
    if (opt::CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
        GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
      insert_index.push_back(1);
    } else {
      for (size_t i = 1; i < cnode->size(); ++i) {
        insert_index.push_back(i);
      }
    }
  }
  for (auto &index : insert_index) {
    if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

STATUS InsertTranspose::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                            const std::vector<int> &perm) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (!cnode->abstract()->isa<abstract::AbstractTuple>()) {
    if (GenNewInput(func_graph, cnode, perm, false) != lite::RET_OK) {
      MS_LOG(ERROR) << "generate a new input failed.";
      return lite::RET_ERROR;
    }
  } else {
    auto node_users = func_graph->manager()->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto post_node = node_user.first;
      CNodePtr tuple_get_item = nullptr;
      if (!opt::CheckPrimitiveType(post_node, prim::kPrimTupleGetItem)) {
        if (!train_flag_) {
          MS_LOG(ERROR) << "post node is invalid.";
          return lite::RET_ERROR;
        } else {
          tuple_get_item = opt::GenTupleGetItemNode(func_graph, cnode, 0);
          post_node = tuple_get_item;
          func_graph->manager()->Replace(cnode, tuple_get_item);
        }
      }
      if (func_graph->manager()->node_users()[post_node].empty()) {
        continue;
      }
      auto post_cnode = post_node->cast<CNodePtr>();
      if (GenNewInput(func_graph, post_cnode, perm, false) != lite::RET_OK) {
        MS_LOG(ERROR) << "generate a new input failed.";
        return lite::RET_ERROR;
      }
      if (tuple_get_item != nullptr) {
        func_graph->manager()->Replace(tuple_get_item, tuple_get_item->input(1));
      }
    }
  }
  return lite::RET_OK;
}

STATUS InsertTranspose::HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (fmk_type_ == lite::converter::FmkType_TF || fmk_type_ == lite::converter::FmkType_TFLITE) {
    return lite::RET_NO_CHANGE;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto node = cnode->input(i);
    if (!utils::isa<ParameterPtr>(node)) {
      continue;
    }
    auto param_node = node->cast<ParameterPtr>();
    if (param_node->has_default()) {
      continue;
    }
    auto abstract_base = param_node->abstract();
    if (abstract_base == nullptr) {
      MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
      return lite::RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
      MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
    if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
      MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
      return lite::RET_ERROR;
    }
    auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
    if (shape_vector.size() != 4) {
      continue;
    }
    if (func_graph->get_inputs().size() == 1 && fmk_type_ == lite::converter::FmkType_ONNX && shape_vector[3] == 3 &&
        shape_vector[1] == -1) {
      continue;
    }
    std::vector<int64_t> new_dims = {shape_vector[NCHW_SHAPE::NCHW_N], shape_vector[NCHW_SHAPE::NCHW_H],
                                     shape_vector[NCHW_SHAPE::NCHW_W], shape_vector[NCHW_SHAPE::NCHW_C]};
    abstract_tensor->set_shape(std::make_shared<abstract::Shape>(new_dims));
    auto trans_cnode = opt::GenTransposeNode(func_graph, param_node, NH2NC, param_node->fullname_with_scope() + "_pre");
    auto new_input_prim = GetValueNode<PrimitivePtr>(trans_cnode->cast<CNodePtr>()->input(0));
    new_input_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
    if (trans_cnode == nullptr) {
      MS_LOG(ERROR) << "generate a transpose node failed.";
      return lite::RET_ERROR;
    }
    func_graph->manager()->Replace(param_node, trans_cnode);
  }
  return lite::RET_OK;
}

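// HandleGraphNode (below) wraps a format-sensitive op in a transpose "sandwich":
//   y = Transpose(Op(Transpose(x, pre_perm)), post_perm)
// e.g. pre_perm = NC2NH = {0, 2, 3, 1} and post_perm = NH2NC = {0, 3, 1, 2},
// so the surrounding graph keeps its original layout while the op sees NHWC.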
STATUS InsertTranspose::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  opt::TransTypePair trans_info;
  GetTransNodeFormatType(cnode, &trans_info);
  if (trans_info.pre_ == opt::kNONE || trans_info.post_ == opt::kNONE) {
    return lite::RET_NO_CHANGE;
  }
  auto before_perm = trans_info.pre_ == opt::kNHWC2NCHW ? NH2NC : NC2NH;
  auto after_perm = trans_info.post_ == opt::kNCHW2NHWC ? NC2NH : NH2NC;
  if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam) || opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    return RET_OK;
  }
  if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}

void InsertTranspose::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto sub_inputs = sub_graph->get_inputs();
  sub_inputs_map_[sub_graph] = sub_inputs;
  for (auto &node : sub_inputs) {
    auto param_node = node->cast<ParameterPtr>();
    MS_ASSERT(param_node != nullptr);
    auto node_name = node->fullname_with_scope();
    auto last_underline = node_name.find_last_of("_");
    node_name = node_name.substr(0, last_underline);
    last_underline = node_name.find_last_of("_");
    auto index = std::stoi(node_name.substr(last_underline + 1)) + 3;
    param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
    if (utils::isa<CNodePtr>(cnode->input(index))) {
      ShapeVector shape_vec = {-1};
      auto out_cnode = cnode->input(index)->cast<CNodePtr>();
      MS_ASSERT(out_cnode != nullptr);
      auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
      if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
        param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
      }
    } else {
      lite::DataInfo data_info;
      if (utils::isa<ParameterPtr>(cnode->input(index))) {
        if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
          param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
        }
        continue;
      }
      auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
      if (status != lite::RET_OK) {
        continue;
      }
      ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
      if (data_info.data_.empty()) {
        param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
      } else {
        param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
                                                                       data_info.data_.data(), data_info.data_.size()));
      }
    }
  }
}

void InsertTranspose::ResetSubGraphInput() {
  for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) {
    auto &sub_graph = iter->first;
    auto &sub_inputs = iter->second;
    auto manager = sub_graph->manager();
    MS_ASSERT(manager != nullptr);
    for (auto &sub_input : sub_inputs) {
      auto param_node = sub_graph->add_parameter();
      MS_ASSERT(param_node != nullptr);
      param_node->set_abstract(sub_input->abstract()->Clone());
      param_node->set_name(sub_input->fullname_with_scope());
      manager->Replace(sub_input, param_node);
      auto sub_param_input = sub_input->cast<ParameterPtr>();
      MS_ASSERT(sub_param_input != nullptr);
      sub_param_input->set_default_param(nullptr);
    }
  }
}

void InsertTranspose::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  auto origin_input = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  for (size_t i = 1; i < return_node->size(); ++i) {
    if (!opt::CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
      continue;
    }
    auto node_name = return_node->input(i)->fullname_with_scope();
    if (node_name.substr(node_name.size() - 5) != "_post") {
      continue;
    }
    auto trans_cnode = return_node->input(i)->cast<CNodePtr>();
    MS_ASSERT(trans_cnode != nullptr);
    auto trans_input = trans_cnode->input(1);
    auto trans_input_name = trans_input->fullname_with_scope();
    if (utils::isa<ParameterPtr>(trans_input)) {
      trans_input->cast<ParameterPtr>()->set_name(node_name);
    } else if (utils::isa<CNodePtr>(trans_input)) {
      trans_input->cast<CNodePtr>()->set_fullname_with_scope(node_name);
    }
    trans_input_name = trans_input_name.substr(0, trans_input_name.find_last_of("_")) + "_cnode";
    trans_cnode->set_fullname_with_scope(trans_input_name);
  }
  return_node->set_inputs(origin_input);
}

void InsertTranspose::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
  MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
  auto return_node = sub_graph->get_return();
  auto origin_inputs = return_node->inputs();
  lite::RemoveIfDepend(return_node);
  lite::RemoveIfMakeTuple(return_node);
  AbstractBasePtrList abstract_list;
  bool infer_done = true;
  for (size_t i = 1; i < return_node->size(); ++i) {
    auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
    MS_ASSERT(abstract_base != nullptr);
    abstract_list.emplace_back(abstract_base->Clone());
    auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
    MS_ASSERT(abstract_tensor != nullptr);
    auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
    MS_ASSERT(shape_ptr != nullptr);
    auto shape = shape_ptr->shape();
    if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
      infer_done = false;
    }
    if (utils::isa<CNodePtr>(return_node->input(i))) {
      auto input_cnode = return_node->input(i)->cast<CNodePtr>();
      if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
        input_cnode = input_cnode->input(1)->cast<CNodePtr>();
      }
      auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
        infer_done = false;
      }
    }
  }
  return_node->set_inputs(origin_inputs);
  if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
    cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  } else {
    if (abstract_list.size() != 1) {
      MS_LOG(ERROR) << "cnode output is invalid.";
    }
    cnode->set_abstract(abstract_list.front());
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
}

bool InsertTranspose::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (main_graph) {
      status = HandleGraphInput(func_graph, cnode);
      if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
        return false;
      }
    }
    if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      continue;
    }
    status = HandleGraphNode(func_graph, cnode);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      return false;
    }
  }
  return true;
}

bool InsertTranspose::ResetFuncGraph(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->GetAttr(opt::kInferDone) != nullptr) {
      prim->EraseAttr(opt::kInferDone);
    }
    if (opt::CheckPrimitiveType(node, prim::kPrimIf) || opt::CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)ResetFuncGraph(sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        return false;
      }
      (void)ResetFuncGraph(sub_func_graph);
    }
  }
  return true;
}

bool InsertTranspose::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    auto prim = GetValueNode<PrimitivePtr>(node);
    if (prim == nullptr) {
      continue;
    }
  }
  // Insert transposes for ops whose inputs must be NHWC; which ops qualify depends on the framework.
  // In this process, transposes may be fused, so the original graph may not be restorable.
  if (!BasicProcess(func_graph, true)) {
    MS_LOG(ERROR) << "run framework transpose unify failed.";
    return false;
  }
  ResetSubGraphInput();
  return true;
}
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,62 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_

#include <vector>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include "utils/utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/anf_exporter/fetch_content.h"

using mindspore::lite::converter::FmkType;
namespace mindspore {
namespace lite {
class InsertTranspose {
 public:
  InsertTranspose(FmkType fmk_type, bool train_flag) : fmk_type_(fmk_type), train_flag_(train_flag) {}
  ~InsertTranspose() = default;
  bool Run(const FuncGraphPtr &func_graph);
  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
                     size_t index = 0);

 private:
  AnfNodePtr GenNewInputWithoutShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                     const std::vector<int> &perm, bool before, size_t index);
  bool ResetFuncGraph(const FuncGraphPtr &func_graph);
  bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
  void GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info);
  STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  void ResetSubGraphInput();
  void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  FmkType fmk_type_{lite::converter::FmkType_MS};
  bool train_flag_{false};
  std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_INSERT_TRANSPOSE_H_
@@ -35,6 +35,7 @@
#include "tools/converter/parser/onnx/onnx_pad_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "ops/transpose.h"
#include "tools/converter/parser/insert_transpose.h"

using mindspore::lite::converter::FmkType_ONNX;
namespace mindspore {
@@ -90,6 +91,11 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_ONNX, false);
  if (!insert_transpose->Run(res_graph_)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
    MS_LOG(ERROR) << "WeightFormatTransform failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -33,6 +33,7 @@
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/insert_transpose.h"

using mindspore::lite::converter::FmkType_TF;
namespace mindspore {
@@ -575,6 +576,11 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TF, false);
  if (!insert_transpose->Run(res_graph_)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
    MS_LOG(ERROR) << "WeightFormatTransform failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -30,6 +30,7 @@
#include "tools/converter/converter_flags.h"
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/insert_transpose.h"

using mindspore::lite::converter::FmkType_TFLITE;
namespace mindspore::lite {
@@ -104,6 +105,11 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto insert_transpose = std::make_shared<InsertTranspose>(lite::converter::FmkType_TFLITE, false);
  if (!insert_transpose->Run(res_graph_)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
    MS_LOG(ERROR) << "WeightFormatTransform failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -26,7 +26,6 @@
namespace mindspore {
namespace opt {
constexpr auto kInferDone = "infer_done";
constexpr auto kTransDone = "trans_done";
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
struct TransTypePair {
  FormatTransNodeType pre_;
@@ -0,0 +1,176 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/transpose_fusion.h"
#include <unordered_map>
#include <memory>
#include <vector>
#include "tools/converter/quant_param_holder.h"
#include "mindspore/core/ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore::opt {
namespace {
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};
}  // namespace
bool IsBNCNode(const BaseRef &n) {
  if (utils::isa<AnfNodePtr>(n)) {
    auto anf_node = utils::cast<AnfNodePtr>(n);
    return CheckPrimitiveType(anf_node, prim::kPrimBatchNorm) ||
           CheckPrimitiveType(anf_node, prim::kPrimFusedBatchNorm);
  }
  return false;
}

VectorRef TransposeFusion::DefineBNPattern() const {
  auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto conv_var = std::make_shared<CondVar>(IsConvNode);
  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
  auto bn_var = std::make_shared<CondVar>(IsBNCNode);
  auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
  auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
  auto bn_other_var = std::make_shared<SeqVar>();
  VectorRef bn_ref = VectorRef({bn_var, transpose_conv_ref, bn_mean_var, bn_variable_var, bn_other_var});
  return bn_ref;
}

VectorRef TransposeFusion::DefineActivationscalePattern() const {
  auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto conv_var = std::make_shared<CondVar>(IsConvNode);
  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
  auto scale_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
  auto scale_var_1 = std::make_shared<CondVar>(IsParamNode);
  auto scale_var_2 = std::make_shared<SeqVar>();
  VectorRef scale_ref = VectorRef({scale_var, transpose_conv_ref, scale_var_1, scale_var_2});
  return scale_ref;
}

VectorRef TransposeFusion::DefineActivationPattern() const {
  auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto conv_var = std::make_shared<CondVar>(IsConvNode);
  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
  auto act_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
  VectorRef act_ref = VectorRef({act_var, transpose_conv_ref});
  return act_ref;
}

VectorRef TransposeFusion::DefineBiasAddPattern() const {
  auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto conv_var = std::make_shared<CondVar>(IsConvNode);
  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef transpose_conv_ref = VectorRef({transpose_var, conv_var, transpose_param});
  auto bias_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
  auto bias_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef act_ref = VectorRef({bias_var, transpose_conv_ref, bias_param});
  return act_ref;
}

VectorRef TransposeFusion::DefineTransTransPattern() const {
  auto transpose_var_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto transpose_var_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
  VectorRef trans_trans_ref = VectorRef({transpose_var_2, transpose_var_1, transpose_param});
  return trans_trans_ref;
}

std::unordered_map<std::string, VectorRef> TransposeFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  patterns["BNPatternName"] = DefineBNPattern();
  patterns["ActivationPatternName"] = DefineActivationPattern();
  patterns["BiasAddPatternName"] = DefineBiasAddPattern();
  patterns["ScalePatternName"] = DefineActivationscalePattern();
  patterns["TransTransPatternName"] = DefineTransTransPattern();
  return patterns;
}

CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const AnfNodePtr &perm,
                          const std::string &cnode_name) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto trans_prim = std::make_shared<ops::Transpose>();
  MS_ASSERT(trans_prim != nullptr);
  auto cnode = func_graph->NewCNode(trans_prim, {input_node, perm});
  MS_ASSERT(cnode != nullptr);
  cnode->set_fullname_with_scope(cnode_name);
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(2, 1);
  auto trans_insert_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  trans_insert_prim->AddAttr("quant_params", quant_params_holder);
  return cnode;
}

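// TransTransFusion (below) cancels a pair of mutually inverse transposes:
//   Transpose(Transpose(x, {0, 3, 1, 2}), {0, 2, 3, 1}) == x
// (and likewise with the perms swapped); the match then rewires users to x.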
AnfNodePtr TransposeFusion::TransTransFusion(const mindspore::FuncGraphPtr &func_graph,
                                             const mindspore::AnfNodePtr &node) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto trans_cnode_2 = node->cast<CNodePtr>();
  if (!CheckPrimitiveType(trans_cnode_2, prim::kPrimTranspose) ||
      !CheckPrimitiveType(trans_cnode_2->input(1), prim::kPrimTranspose)) {
    return nullptr;
  }
  std::vector<int> post_perm;
  if (GetTransposePerm(trans_cnode_2, &post_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get transpose perm failed.";
    return nullptr;
  }
  std::vector<int> pre_perm;
  auto pre_node = trans_cnode_2->input(1);
  auto pre_cnode = pre_node->cast<CNodePtr>();
  if (pre_cnode == nullptr) {
    return nullptr;
  }
  if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "get transpose perm failed.";
    return nullptr;
  }
  if ((pre_perm == NH2NC && post_perm == NC2NH) || (pre_perm == NC2NH && post_perm == NH2NC)) {
    return pre_cnode->input(1);
  }
  return nullptr;
}

AnfNodePtr TransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                    const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (pattern_name == "TransTransPatternName") {
    return TransTransFusion(func_graph, node);
  }
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  if (CheckIfCNodeIsNull(node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  auto any_cnode = node->cast<CNodePtr>();
  const auto transpose_node = any_cnode->input(1);
  if (CheckIfCNodeIsNull(transpose_node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  const CNodePtr &transpose_cnode = transpose_node->cast<CNodePtr>();
  auto perm_node = transpose_cnode->input(2);
  auto trans_post_node = GenTransposeNode(func_graph, any_cnode, perm_node, any_cnode->fullname_with_scope() + "_post");
  auto tr = func_graph->manager()->Transact();
  tr.SetEdge(any_cnode, 1, transpose_cnode->input(1));
  tr.Commit();
  return trans_post_node;
}
}  // namespace mindspore::opt
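TransposeFusion plugs into the converter's fusion pass manager like any other fusion pass; the anf_transform.cc hunk above shows the actual registration:

```cpp
fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
```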
@@ -0,0 +1,48 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_

#include <string>
#include <unordered_map>
#include "tools/optimizer/graph/unify_format_pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h"

namespace mindspore {
namespace opt {
class TransposeFusion : public MultiplePatternProcessPass {
 public:
  explicit TransposeFusion(const std::string &name = "transpose_fusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {}

  ~TransposeFusion() override = default;

  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
  VectorRef DefineBNPattern() const;
  VectorRef DefineActivationPattern() const;
  VectorRef DefineActivationscalePattern() const;
  VectorRef DefineTransTransPattern() const;
  VectorRef DefineBiasAddPattern() const;
  AnfNodePtr TransTransFusion(const mindspore::FuncGraphPtr &func_graph, const mindspore::AnfNodePtr &node) const;
  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
                     const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_TRANSPOSE_FUSION_H_
@@ -165,6 +165,8 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
  }
  if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
    auto set_status = SetCNodeAbstract(cnode, outputs, ret);
    auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
    if (set_status != lite::RET_OK) {
      MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
      FreeTensors(&inputs);
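The two added lines stamp the inferred input layout onto the node's primitive as ops::kFormat; SetFormatForCnode in fetch_content.cc (earlier in this diff) reads the same attribute back. The round trip, quoting the calls from the hunks:

```cpp
// write side, after shape inference (node_infer_shape.cc):
cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
// read side (fetch_content.cc):
auto value = prim->GetAttr(ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
  data_info->format_ = GetValue<int64_t>(value);
}
```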
@@ -336,6 +338,7 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite:
    tensor = GetCNodeTensorListVarInput(data_info);
  } else {
    tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_);
    tensor->set_format((Format)(data_info.format_));
  }
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "new a lite tensor failed";
@@ -20,6 +20,7 @@
#include <vector>
#include <memory>
#include <string>
#include <map>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "tools/anf_exporter/fetch_content.h"
@@ -41,9 +42,9 @@ class NodeInferShape {
  bool JudgeOpSupportInfer(const CNodePtr &cnode);
  std::vector<int> GetInputShape(const CNodePtr &cnode, size_t index);
  std::vector<int> GetIntVecInput(const CNodePtr &cnode, size_t index);
  STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);

 private:
  STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
  STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
  STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
  lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);
@@ -32,7 +32,13 @@ constexpr size_t kInputTripleNum = 3;
int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto first_input = cnode->input(1);
  if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
    first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
  }
  auto second_input = cnode->input(2);
  if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
    second_input = cnode->input(2)->cast<CNodePtr>()->input(1);
  }
  AnfNodePtr must_monad = nullptr;
  AnfNodePtr not_must_monad = nullptr;
  if (utils::isa<ValueNode>(first_input)) {
@@ -72,6 +78,12 @@ int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr
    pre_node = cnode->input(2);
    post_node = cnode->input(1);
  }
  if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
    pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
  }
  if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
    post_node = cnode->input(2)->cast<CNodePtr>()->input(1);
  }
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto node_users = manager->node_users()[pre_node];
@@ -102,6 +114,10 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
  auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
    manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
    return RET_OK;
  }
  manager->Replace(cnode->input(0), make_tuple_prim);
  return lite::RET_OK;
}
@@ -335,6 +335,9 @@ STATUS UnifyFormatPass::GenNewInput(const FuncGraphPtr &func_graph, const CNodeP
    }
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  MS_ASSERT(manager != nullptr);
  auto tr = manager->Transact();
  if (before) {
@@ -428,6 +431,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
    }
  }
  status = node_infer_shape_.InferShape(cnode);

  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
@ -523,47 +527,8 @@ STATUS UnifyFormatPass::HandleGraphInput(const FuncGraphPtr &func_graph, const C
      MS_LOG(ERROR) << "infer shape failed.";
      return lite::RET_ERROR;
    }
    func_graph->manager()->Replace(param_node, trans_cnode);
  }
  return lite::RET_OK;
}

STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  MS_ASSERT(prim != nullptr);
  if (prim->GetAttr(kTransDone) != nullptr && GetValue<bool>(prim->GetAttr(kTransDone))) {
    return lite::RET_OK;
  }
  prim->AddAttr(kTransDone, MakeValue<bool>(true));
  TransTypePair trans_info;
  GetTransNodeFormatType(cnode, &trans_info);
  if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) {
    if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
      return lite::RET_OK;
    }
    auto status = node_infer_shape_.InferShape(cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope();
      return lite::RET_ERROR;
    }
    return lite::RET_NO_CHANGE;
  }
  auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH;
  auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC;
  if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  auto status = node_infer_shape_.InferShape(cnode);
  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
    MS_LOG(ERROR) << "infer shape failed.";
    return lite::RET_ERROR;
  }
  if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope();
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
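
HandleGraphNode picks its pre/post permutations from the NH2NC/NC2NH constants. Their definitions are not shown in this hunk, but they follow the nhwc2nchw/nchw2nhwc arrays used by the nnacl transpose infer code; a minimal sketch of the convention and of how a perm rewrites a shape (ApplyPerm is a hypothetical illustration, not part of the pass):

#include <cstdint>
#include <vector>

// Assumed values, matching the nhwc2nchw/nchw2nhwc arrays in the infer code.
const std::vector<int> NH2NC = {0, 3, 1, 2};  // NHWC -> NCHW
const std::vector<int> NC2NH = {0, 2, 3, 1};  // NCHW -> NHWC

// A transpose with permutation perm maps out_shape[i] = in_shape[perm[i]].
std::vector<int64_t> ApplyPerm(const std::vector<int64_t> &in, const std::vector<int> &perm) {
  std::vector<int64_t> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = in[perm[i]];
  }
  return out;
}
// e.g. ApplyPerm({1, 224, 224, 3}, NH2NC) yields {1, 3, 224, 224}.
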
@ -763,58 +728,6 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
  prim->AddAttr(kInferDone, MakeValue<bool>(infer_done));
}

bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(func_graph->get_return());
  int status;
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (main_graph && !need_reset_) {
      status = HandleGraphInput(func_graph, cnode);
      if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
        return false;
      }
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      (void)BasicProcess(sub_func_graph, false);
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      continue;
    }
    status = HandleGraphNode(func_graph, cnode);
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      return false;
    }
  }
  return true;
}
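
BasicProcess, RunNodeInferShape, and DoFixFormat all walk If/While control-flow nodes the same way: input(1) and input(2) of the cnode hold the two sub-funcgraphs (then/else branches for If, cond/body for While), each processed with SetSubGraphInput/SetSubGraphOutput bookkeeping around a recursive call. A minimal sketch of that shared shape, with a hypothetical visitor callback and the bookkeeping elided:

#include <functional>

// Hypothetical traversal skeleton; `visit` stands in for BasicProcess,
// RunNodeInferShape, or DoFixFormat, and the SetSubGraph* calls are elided.
bool ForEachSubGraph(const CNodePtr &cnode, const std::function<bool(const FuncGraphPtr &)> &visit) {
  for (size_t i = 1; i <= 2; ++i) {  // input(1) and input(2) hold the sub-funcgraphs
    auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(i));
    if (sub_func_graph == nullptr) {
      return false;  // the real passes also record lite::RET_NULL_PTR here
    }
    if (!visit(sub_func_graph)) {
      return false;
    }
  }
  return true;
}
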
bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name"));
@ -935,33 +848,6 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
    if (prim->GetAttr(kInferDone) != nullptr) {
      prim->EraseAttr(kInferDone);
    }
    if (prim->GetAttr(kTransDone) != nullptr) {
      prim->EraseAttr(kTransDone);
    }
    if (pre_insert_trans_.find(cnode) != pre_insert_trans_.end()) {
      manager->Replace(node, cnode->input(1));
    }
    if (post_insert_trans_.find(cnode) != post_insert_trans_.end()) {
      auto cnode_abstract = cnode->abstract();
      if (!utils::isa<abstract::AbstractTensorPtr>(cnode_abstract)) {
        MS_LOG(ERROR) << "abstract is not abstract tensor.";
        return false;
      }
      auto cnode_abstract_tensor = cnode_abstract->cast<abstract::AbstractTensorPtr>();
      if (!utils::isa<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
        MS_LOG(ERROR) << "shape of abstract tensor should be ShapePtr.";
        return false;
      }
      auto shape_ptr = utils::cast<abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
      auto input_abstract = GetCNodeInputAbstract(cnode, 1);
      if (!utils::isa<abstract::AbstractTensorPtr>(input_abstract)) {
        MS_LOG(ERROR) << "abstract is not abstract tensor.";
        return false;
      }
      auto input_abstract_tensor = input_abstract->cast<abstract::AbstractTensorPtr>();
      input_abstract_tensor->set_shape(shape_ptr);
      manager->Replace(node, cnode->input(1));
    }
    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
@ -1025,18 +911,140 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) {
    MS_LOG(ERROR) << "exist op cannot support infer shape.";
    return false;
  }
  need_reset_ = true;
  // insert transpose for some ops whose format must be NHWC, which depends on the framework.
  // In this process, transpose ops are not fused, so the original graph can be restored.
  if (!BasicProcess(func_graph, true)) {
    MS_LOG(ERROR) << "run framework transpose unify failed.";
  if (!RunNodeInferShape(func_graph)) {
    MS_LOG(ERROR) << "RunNodeInferShape failed.";
    return false;
  }
  ResetSubGraphInput();
  // delete the inserted transpose ops and update op output shapes.
  if (!ResetFuncGraph(func_graph)) {
    MS_LOG(ERROR) << "reset func_graph failed.";
    return false;
  ResetFuncGraph(func_graph);
  return true;
}

bool UnifyFormatPass::RunNodeInferShape(const FuncGraphPtr &func_graph) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!RunNodeInferShape(sub_func_graph)) {
        MS_LOG(ERROR) << "subgraph infer shape failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(cnode, sub_func_graph);

      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!RunNodeInferShape(sub_func_graph)) {
        MS_LOG(ERROR) << "subgraph infer shape failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      continue;
    }
    auto status = node_infer_shape_.InferShape(cnode);
    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
      MS_LOG(ERROR) << "infer shape failed." << cnode->fullname_with_scope();
      return false;
    }
  }
  return true;
}

bool UnifyFormatPass::RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto prim_node = cnode->input(0);
  auto prim = GetValueNode<PrimitivePtr>(prim_node);
  auto &nchw_op = GetNCHWOpMap();
  if (!utils::isa<CNodePtr>(cnode->input(1))) {
    return true;
  }
  if (utils::isa<CNodePtr>(cnode->input(1))) {
    auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
    if (nchw_op.find(prim->name()) != nchw_op.end() && format != NCHW) {
      InsertPreTransNode(func_graph, cnode, {0, 3, 1, 2});
      InsertPostTransNode(func_graph, cnode, {0, 2, 3, 1});
    }
  }
  {
    if (CheckPrimitiveType(cnode, prim::kPrimTranspose)) {
      MS_ASSERT(func_graph != nullptr && cnode != nullptr);
      auto manager = func_graph->manager();
      if (manager == nullptr) {
        manager = Manage(func_graph, true);
      }
      auto shape = node_infer_shape_.GetInputShape(cnode, 1);
      std::vector<int> perm;
      auto status = GetTransposePerm(cnode, &perm);
      if (status != RET_OK) {
        return false;
      }
      if (!shape.empty() && shape.size() != perm.size()) {
        manager->Replace(cnode, cnode->input(1));
      }
    }
  }
  return true;
}
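
The Transpose branch above removes a transpose whose known input rank disagrees with its perm length (for example a 4-D perm left attached to a tensor that turned out to be 3-D); a minimal sketch of just that predicate, under hypothetical naming:

#include <vector>

// Hypothetical standalone form of the removal condition used above: a
// transpose is a removable wrapper when the input shape is known and its
// rank does not match the permutation length.
bool IsRemovableTranspose(const std::vector<int> &input_shape, const std::vector<int> &perm) {
  return !input_shape.empty() && input_shape.size() != perm.size();
}
// e.g. IsRemovableTranspose({1, 224, 224}, {0, 3, 1, 2}) is true, and
// RunDoFixFormat then replaces the transpose with its input.
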
bool UnifyFormatPass::DoFixFormat(const FuncGraphPtr &func_graph) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsSpecialType(cnode)) {
      continue;
    }
    if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!DoFixFormat(sub_func_graph)) {
        MS_LOG(ERROR) << "subgraph fix format failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(cnode, sub_func_graph);

      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        return false;
      }
      SetSubGraphInput(cnode, sub_func_graph);
      if (!DoFixFormat(sub_func_graph)) {
        MS_LOG(ERROR) << "subgraph fix format failed.";
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
        return false;
      }
      SetSubGraphOutput(cnode, sub_func_graph);
      SetSubGraphAbstract(cnode, sub_func_graph);
      continue;
    }
    if (!RunDoFixFormat(func_graph, cnode)) {
      return false;
    }
  }
  return true;
}

@ -1049,22 +1057,23 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
    if (prim == nullptr) {
      continue;
    }
    if (prim->GetAttr(kTransDone) != nullptr) {
      return true;
    }
  }
  if (!JudgeAllOpsCanInfer(func_graph)) {
    MS_LOG(ERROR) << "exist op cannot support infer shape.";
    return false;
  }
  // insert transpose for some ops whose format must be NHWC, which depends on the framework.
  // In this process, transpose ops can be fused, so the original graph may not be restorable.
  if (!BasicProcess(func_graph, true)) {
    MS_LOG(ERROR) << "run framework transpose unify failed.";
  if (!RunNodeInferShape(func_graph)) {
    MS_LOG(ERROR) << "infer shape failed.";
    return false;
  }
  ResetSubGraphInput();
  // if the input format of an op can be NHWC, try to transform the op to decrease the number of transpose ops.

  if (!DoFixFormat(func_graph)) {
    MS_LOG(ERROR) << "DoFixFormat failed.";
    return false;
  }
  ResetSubGraphInput();

  if (!DecreaseTransposeForSingleOp(func_graph)) {
    MS_LOG(ERROR) << "run local trans insert optimizer failed.";
    return false;

@ -45,23 +45,25 @@ class UnifyFormatPass : public Pass {
  bool RunOnlyForShape(const FuncGraphPtr &func_graph);

 private:
  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
                     size_t index = 0);
  bool RunDoFixFormat(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  bool DoFixFormat(const FuncGraphPtr &func_graph);
  bool RunNodeInferShape(const FuncGraphPtr &func_graph);
  bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph);
  bool ResetFuncGraph(const FuncGraphPtr &func_graph);
  bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
  bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
  bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
  bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
                     size_t index = 0);

  void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info);
  STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                              std::set<CNodePtr> *visit_transposes);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
  STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
  void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
  void ResetSubGraphInput();
  void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);