!24351 [lite] fix a bug in the nh2nc pass

Merge pull request !24351 from 徐安越/master3
i-robot 2021-10-11 09:08:29 +00:00 committed by Gitee
commit 8b18dc936e
10 changed files with 135 additions and 109 deletions

View File

@@ -169,6 +169,9 @@ int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
}
} else {
const TensorC *shape_tensor = inputs[1];
if (shape_tensor->data_ == NULL) {
return NNACL_INFER_INVALID;
}
dst_shape_size = GetElementNum(shape_tensor);
if (dst_shape_size > MAX_SHAPE_SIZE) {
return NNACL_INPUT_TENSOR_ERROR;
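
For context, the guard added here follows a recurring NNACL pattern: a shape operand may be a runtime tensor whose data is not yet materialized when inference runs at convert time, so the infer function must bail out with NNACL_INFER_INVALID (retry at runtime) instead of dereferencing a NULL data pointer. A minimal sketch of the pattern, reusing the names visible in this hunk (ExampleShapeGuard itself is hypothetical):

// Hedged sketch: defer inference when the shape tensor carries no constant
// data yet; reject it outright when its rank exceeds what NNACL represents.
static int ExampleShapeGuard(const TensorC *shape_tensor, int *dst_shape_size) {
  if (shape_tensor->data_ == NULL) {
    return NNACL_INFER_INVALID;  // data arrives at runtime; infer again then
  }
  *dst_shape_size = GetElementNum(shape_tensor);
  if (*dst_shape_size > MAX_SHAPE_SIZE) {
    return NNACL_INPUT_TENSOR_ERROR;  // malformed shape input
  }
  return NNACL_OK;
}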

View File

@@ -27,37 +27,33 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
return NNACL_INFER_INVALID;
}
int shape_size = GetElementNum(shape_tensor);
void *origin_data = shape_tensor->data_;
if (origin_data == NULL) {
return NNACL_INFER_INVALID;
}
switch (shape_size) {
case 2:
case 4: {
int height_index = 0;
int width_index = 1;
if (shape_size == 4) {
height_index = kNHWC_H;
width_index = kNHWC_W;
}
if (shape_tensor->data_type_ == kNumberTypeInt32) {
int32_t *data = (int32_t *)(shape_tensor->data_);
if (data == NULL) {
return NNACL_INFER_INVALID;
}
if (GetElementNum(shape_tensor) < 4) {
return NNACL_ERR;
}
param->new_height_ = data[1];
param->new_width_ = data[2];
int32_t *data = (int32_t *)(origin_data);
param->new_height_ = data[height_index];
param->new_width_ = data[width_index];
} else if (shape_tensor->data_type_ == kNumberTypeFloat32) {
float *data = (float *)(shape_tensor->data_);
if (data == NULL) {
return NNACL_INFER_INVALID;
}
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[1]), GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[2]), GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
param->new_height_ = round(data[1] * GetHeight(input));
param->new_width_ = round(data[2] * GetWidth(input));
float *data = (float *)(origin_data);
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[height_index]), GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[width_index]), GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
param->new_height_ = round(data[height_index] * GetHeight(input));
param->new_width_ = round(data[width_index] * GetWidth(input));
} else if (shape_tensor->data_type_ == kNumberTypeFloat16) {
uint16_t *data = (uint16_t *)(shape_tensor->data_);
if (data == NULL) {
return NNACL_INFER_INVALID;
}
float scale_height = ShortToFloat32(data[1]);
float scale_width = ShortToFloat32(data[2]);
float scale_height = ShortToFloat32(data[height_index]);
float scale_width = ShortToFloat32(data[width_index]);
MS_CHECK_INT_MUL_NOT_OVERFLOW(scale_height, GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
MS_CHECK_INT_MUL_NOT_OVERFLOW(scale_width, GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
param->new_height_ = round(scale_height * GetHeight(input));
@@ -65,23 +61,11 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
}
break;
}
case 2: {
int32_t *data = (int32_t *)(shape_tensor->data_);
if (data == NULL) {
return NNACL_INFER_INVALID;
}
param->new_height_ = data[0];
param->new_width_ = data[1];
break;
}
case 1: {
// caffe zoom_factor
int scale;
if (shape_tensor->data_type_ == kNumberTypeInt32) {
int *data = (int *)(shape_tensor->data_);
if (data == NULL) {
return NNACL_INFER_INVALID;
}
int *data = (int *)(origin_data);
scale = data[0];
} else {
return NNACL_ERR;
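
The rewrite above folds the old case 2 into case 4 through an explicit fallthrough: a size-2 shape tensor stores {H, W} directly at indices 0 and 1, while a size-4 tensor is a full NHWC shape whose H and W sit at kNHWC_H and kNHWC_W. The single origin_data NULL check up front also replaces the three per-branch checks. A worked sketch of the shared index selection, assuming the usual NHWC layout constants kNHWC_H == 1 and kNHWC_W == 2 (DemoResizeIndices is hypothetical):

// Hedged sketch of the merged case 2 / case 4 index selection.
static void DemoResizeIndices(int shape_size, const int32_t *data,
                              int *new_height, int *new_width) {
  int height_index = 0;   // size-2 tensor: data = {H, W}
  int width_index = 1;
  if (shape_size == 4) {  // size-4 tensor: full NHWC shape {N, H, W, C}
    height_index = 1;     // kNHWC_H
    width_index = 2;      // kNHWC_W
  }
  *new_height = data[height_index];
  *new_width = data[width_index];
}
// e.g. shape_size = 4, data = {1, 224, 224, 3}  ->  H = 224, W = 224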

View File

@@ -19,7 +19,7 @@
int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2, 2);
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 2);
if (check_ret != NNACL_OK) {
return check_ret;
}
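
The swapped-in check changes the contract: judging by its trailing arguments (1, 2, 2), the old helper allowed either one or two inputs with exactly two outputs, whereas CheckAugmentWithMinSize appears to enforce only lower bounds, so TopK now requires at least two inputs (values plus the k tensor) and two outputs. A hedged sketch of what a minimum-arity check amounts to (DemoCheckMinSize is a hypothetical stand-in, not the NNACL helper):

// Hedged sketch: enforce lower bounds on operand counts, nothing more.
static int DemoCheckMinSize(size_t inputs_size, size_t outputs_size,
                            size_t min_in, size_t min_out) {
  if (inputs_size < min_in || outputs_size < min_out) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}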

View File

@@ -20,7 +20,11 @@
#include <string>
#include <vector>
#include <unordered_map>
#include "ops/transpose.h"
#include "ops/adam.h"
#include "ops/apply_momentum.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/sgd.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/conv1d_inout_adjust.h"
#include "tools/converter/parser/inputs_adjust.h"

View File

@@ -232,7 +232,11 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
std::vector<float> new_shape;
MS_CHECK_TRUE_MSG(!shape_tensor->shape().empty(), RET_NULL_PTR, "out of range.");
if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_4) {
new_shape = {shape_data[kNumIndex_0], shape_data[kNumIndex_2], shape_data[kNumIndex_3], shape_data[kNumIndex_1]};
if (shape_data[kNumIndex_0] != 1 || shape_data[kNumIndex_1] != 1) {
MS_LOG(ERROR) << "Op resize don't support, which N dimension and C dimension is not 1.";
return RET_NOT_SUPPORT;
}
new_shape = {shape_data[kNumIndex_2], shape_data[kNumIndex_3]};
} else if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_2) {
return RET_OK;
} else {
@@ -240,7 +244,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
}
auto new_shape_node = func_graph->add_parameter();
MS_CHECK_TRUE_MSG(new_shape_node != nullptr, RET_NULL_PTR, "new_shape_node is nullptr.");
auto tensor_info = CreateTensorInfo(nullptr, 0, shape_tensor->shape(), shape_tensor->data_type());
auto tensor_info = CreateTensorInfo(nullptr, 0, std::vector<int64_t>{kNumInputSize}, shape_tensor->data_type());
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "create tensor info failed.";
return RET_ERROR;
@@ -250,7 +254,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
MS_LOG(ERROR) << "data is nullptr";
return RET_ERROR;
}
auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), tensor_info->Size());
auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), sizeof(float) * new_shape.size());
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return RET_ERROR;
@@ -268,7 +272,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForVariableShape(const FuncGraphPtr &
auto gather_name = cnode->fullname_with_scope() + "_gather";
auto gather_input = cnode->input(kNumResizeInputShape);
auto abstract = cnode->input(kNumResizeInputShape)->abstract();
std::vector<int> gather_indices = {0, 2, 3, 1}; // NCHW to NHWC
std::vector<int> gather_indices = {kNumIndex_2, kNumIndex_3}; // fetch H and W dimension
auto gather_cnode = opt::GenGatherNode(func_graph, gather_input, gather_indices, gather_name);
if (gather_cnode == nullptr) {
MS_LOG(ERROR) << "create gather cnode failed.";

View File

@@ -102,6 +102,35 @@ static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {};
static const std::unordered_map<std::string, std::vector<size_t>> ToNCHWOpMap = {
{ops::kNameAdam, {10}},
{ops::kNameApplyMomentum, {4}},
{ops::kNameAvgPoolFusion, {1}},
{ops::kNameAvgPoolGrad, {}},
{ops::kNameBatchNorm, {1}},
{ops::kNameBatchNormGrad, {1, 2}},
{ops::kNameBatchToSpace, {1}},
{ops::kNameBiasAdd, {1}},
{ops::kNameBiasAddGrad, {1}},
{ops::kNameConv2DBackpropInputFusion, {1}},
{ops::kNameConv2DBackpropFilterFusion, {1, 2}},
{ops::kNameConv2DFusion, {1}},
{ops::kNameConv2dTransposeFusion, {1}},
{ops::kNameDepthToSpace, {1}},
{ops::kNameFusedBatchNorm, {1}},
{ops::kNameInstanceNorm, {1}},
{ops::kNameLRN, {1}},
{ops::kNameMaxPoolFusion, {1}},
{ops::kNameMaxPoolGrad, {}},
{ops::kNamePReLUFusion, {1}},
{ops::kNameResize, {1}},
{ops::kNameResizeGrad, {}},
{ops::kNameROIPooling, {1}},
{ops::kNameSGD, {2}},
{ops::kNameSpaceToBatch, {1}},
{ops::kNameSpaceToBatchND, {1}},
{ops::kNameSpaceToDepth, {1}}};
// ops whose input format is not fixed; the bool value indicates whether the op has an axis attribute.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
{ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
@@ -113,6 +142,7 @@ static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap() { return ToNCHWOpMap; }
bool IsDynamicFormatOp(const std::string &op_type) {
return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
}
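
The new ToNCHWOpMap records, per op type, which input indices are format-sensitive and therefore need handling when the pass steers the graph toward NCHW; an empty vector (e.g. AvgPoolGrad, MaxPoolGrad, ResizeGrad) appears to mean every input falls under the caller's default rule, mirroring how the removed InsertPreTransNode overload later in this diff expanded an empty list to all inputs. A hedged sketch of how a pass might consume the map (DemoSensitiveInputs is hypothetical; the namespace is assumed from the file's location under tools/optimizer):

// Hedged sketch: query the format-sensitive input indices for an op type.
#include <string>
#include <vector>
#include "tools/optimizer/common/format_utils.h"

std::vector<size_t> DemoSensitiveInputs(const std::string &op_type) {
  const auto &to_nchw = mindspore::opt::GetToNCHWOpMap();
  auto it = to_nchw.find(op_type);
  if (it == to_nchw.end()) {
    return {};  // op not tracked by this pass
  }
  return it->second;  // e.g. {1} for Conv2DFusion: its weight input
}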

View File

@@ -34,6 +34,7 @@ struct TransTypePair {
};
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap();
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap();
const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap();
const std::vector<std::string> &GetDynamicFormatOpList();
bool IsDynamicFormatOp(const std::string &op_type);
bool IsDynamicFormatOpWithAxis(const std::string &op_type);

View File

@@ -26,11 +26,6 @@
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/adam.h"
#include "ops/sgd.h"
#include "ops/apply_momentum.h"
using mindspore::converter::FmkType;
namespace mindspore {
@@ -66,7 +61,7 @@ class ToFormatBase : public Pass {
protected:
virtual STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
virtual void SetSensitiveOps() { sensitive_ops_ = GetToNCHWOpMap(); }
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
const ShapeVector &shape);
virtual bool DecideWhetherInferShapeForNewNode() { return true; }

View File

@@ -168,7 +168,8 @@ int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
if (status != lite::RET_OK) {
return lite::RET_ERROR;
}
if (data_info.shape_.empty() ||
if (data_info.shape_.empty() || (data_info.shape_.size() == 1 && data_info.shape_[0] == 0) ||
std::all_of(data_info.shape_.begin(), data_info.shape_.end(), [](int val) { return val == 1; }) ||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
return lite::RET_OK;
}
@@ -273,42 +274,6 @@ STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto prim_node = cnode->input(0);
MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
auto prim = GetValueNode<PrimitivePtr>(prim_node);
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
auto &specify_nhwc_op_map = GetNHWCOpMap();
auto &specify_nchw_op_map = GetNCHWOpMap();
if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
MS_LOG(ERROR) << "op don't meet nhwc condition.";
return lite::RET_ERROR;
}
std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
? specify_nhwc_op_map.at(prim->name())
: specify_nchw_op_map.at(prim->name());
if (insert_index.empty()) {
if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
insert_index.push_back(1);
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
insert_index.push_back(i);
}
}
}
for (auto &index : insert_index) {
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
TransTypePair *trans_insert_info) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
@@ -329,27 +294,10 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
for (size_t i = 1; i < cnode->size(); ++i) {
if (IsMonadNode(cnode->input(i))) {
continue;
}
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
MS_ASSERT(input_make_tuple != nullptr);
for (size_t j = 1; j < input_make_tuple->size(); ++j) {
if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
}
continue;
}
if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
status = DoPreInsert(func_graph, cnode, trans_insert_info->pre_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Do pre insert failed.";
return lite::RET_ERROR;
}
status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
if (status != lite::RET_OK) {
@@ -364,6 +312,63 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
FormatTransNodeType trans_type) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto abstract = cnode->abstract();
MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
auto abstract_list = abstract_tuple->elements();
MS_CHECK_TRUE_RET(!abstract_list.empty(), lite::RET_OUT_OF_TENSOR_RANGE);
abstract = abstract_list.front();
MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
}
ShapeVector shape;
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "fetch shape from abstract fauled.";
return lite::RET_ERROR;
}
auto HandleFunc = [this, &shape](const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index,
FormatTransNodeType trans_type) -> STATUS {
auto before_perm = trans_type == kNHWC2NCHW ? kNH2NC : kNC2NH;
if (shape.size() == kInputSizeFour && !cnode->input(index)->isa<CNode>()) {
if (ConvertTensorToNCOrNH(func_graph, cnode, index, fmk_type_, train_flag_, trans_type) != lite::RET_OK) {
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
return lite::RET_ERROR;
}
} else if (GenNewInput(func_graph, cnode, before_perm, true, index) != lite::RET_OK) {
MS_LOG(ERROR) << "generate a new input failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
};
for (size_t i = 1; i < cnode->size(); ++i) {
MS_CHECK_TRUE_RET(cnode->input(i) != nullptr, lite::RET_NULL_PTR);
if (IsMonadNode(cnode->input(i))) {
continue;
}
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
MS_ASSERT(input_make_tuple != nullptr);
for (size_t j = 1; j < input_make_tuple->size(); ++j) {
MS_CHECK_TRUE_RET(input_make_tuple->input(j) != nullptr, lite::RET_NULL_PTR);
if (HandleFunc(func_graph, input_make_tuple, j, trans_type) != lite::RET_OK) {
MS_LOG(ERROR) << "handle pre insert failed.";
return lite::RET_ERROR;
}
}
continue;
}
if (HandleFunc(func_graph, cnode, i, trans_type) != lite::RET_OK) {
MS_LOG(ERROR) << "handle pre insert failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &perm) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
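
Taken together, the extracted DoPreInsert above is the core of the nh2nc fix: rather than unconditionally inserting a Transpose before every input, it rewrites 4-D constant inputs in place via ConvertTensorToNCOrNH and only falls back to GenNewInput, which inserts a Transpose node, for inputs produced by other CNodes. A condensed sketch of that per-input dispatch, with hypothetical stand-in types for the graph machinery:

// Hedged summary of HandleFunc's choice in DoPreInsert; the real pass works
// on CNodePtr/FuncGraphPtr and checks the op's inferred output rank.
enum class InputKind { kConstTensor, kCNode };
enum class Action { kConvertInPlace, kInsertTranspose };

Action DemoDecide(InputKind kind, size_t output_rank) {
  // A constant feeding a 4-D op can be permuted at convert time, saving a
  // runtime Transpose; anything else gets an explicit Transpose node whose
  // perm is kNH2NC for kNHWC2NCHW transitions and kNC2NH otherwise.
  if (output_rank == 4 && kind == InputKind::kConstTensor) {
    return Action::kConvertInPlace;
  }
  return Action::kInsertTranspose;
}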

View File

@@ -41,7 +41,6 @@ class DecreaseTransposeAlgo : public Pass {
private:
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
size_t index = 0);
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
@@ -51,6 +50,7 @@ class DecreaseTransposeAlgo : public Pass {
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
STATUS DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
int SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
int ResetSubGraphInput();
int SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);