forked from mindspore-Ecosystem/mindspore
!24351 [lite]fix the bug of nh2nc pass
Merge pull request !24351 from 徐安越/master3
This commit is contained in:
commit
8b18dc936e
|
@ -169,6 +169,9 @@ int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, Tens
|
|||
}
|
||||
} else {
|
||||
const TensorC *shape_tensor = inputs[1];
|
||||
if (shape_tensor->data_ == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
dst_shape_size = GetElementNum(shape_tensor);
|
||||
if (dst_shape_size > MAX_SHAPE_SIZE) {
|
||||
return NNACL_INPUT_TENSOR_ERROR;
|
||||
|
|
|
@ -27,37 +27,33 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
|
|||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
int shape_size = GetElementNum(shape_tensor);
|
||||
void *origin_data = shape_tensor->data_;
|
||||
if (origin_data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
switch (shape_size) {
|
||||
case 2:
|
||||
case 4: {
|
||||
int height_index = 0;
|
||||
int width_index = 1;
|
||||
if (shape_size == 4) {
|
||||
height_index = kNHWC_H;
|
||||
width_index = kNHWC_W;
|
||||
}
|
||||
if (shape_tensor->data_type_ == kNumberTypeInt32) {
|
||||
int32_t *data = (int32_t *)(shape_tensor->data_);
|
||||
if (data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
if (GetElementNum(shape_tensor) < 4) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
param->new_height_ = data[1];
|
||||
param->new_width_ = data[2];
|
||||
int32_t *data = (int32_t *)(origin_data);
|
||||
param->new_height_ = data[height_index];
|
||||
param->new_width_ = data[width_index];
|
||||
} else if (shape_tensor->data_type_ == kNumberTypeFloat32) {
|
||||
float *data = (float *)(shape_tensor->data_);
|
||||
if (data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[1]), GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[2]), GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
param->new_height_ = round(data[1] * GetHeight(input));
|
||||
param->new_width_ = round(data[2] * GetWidth(input));
|
||||
float *data = (float *)(origin_data);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[height_index]), GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[width_index]), GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
param->new_height_ = round(data[height_index] * GetHeight(input));
|
||||
param->new_width_ = round(data[width_index] * GetWidth(input));
|
||||
} else if (shape_tensor->data_type_ == kNumberTypeFloat16) {
|
||||
uint16_t *data = (uint16_t *)(shape_tensor->data_);
|
||||
if (data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
||||
float scale_height = ShortToFloat32(data[1]);
|
||||
float scale_width = ShortToFloat32(data[2]);
|
||||
|
||||
float scale_height = ShortToFloat32(data[height_index]);
|
||||
float scale_width = ShortToFloat32(data[width_index]);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(scale_height, GetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(scale_width, GetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW);
|
||||
param->new_height_ = round(scale_height * GetHeight(input));
|
||||
|
@ -65,23 +61,11 @@ int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
int32_t *data = (int32_t *)(shape_tensor->data_);
|
||||
if (data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
param->new_height_ = data[0];
|
||||
param->new_width_ = data[1];
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
// caffe zoom_factor
|
||||
int scale;
|
||||
if (shape_tensor->data_type_ == kNumberTypeInt32) {
|
||||
int *data = (int *)(shape_tensor->data_);
|
||||
if (data == NULL) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
int *data = (int *)(origin_data);
|
||||
scale = data[0];
|
||||
} else {
|
||||
return NNACL_ERR;
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2, 2);
|
||||
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 2);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,11 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "ops/transpose.h"
|
||||
#include "ops/adam.h"
|
||||
#include "ops/apply_momentum.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "ops/fusion/conv2d_transpose_fusion.h"
|
||||
#include "ops/sgd.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/parser/conv1d_inout_adjust.h"
|
||||
#include "tools/converter/parser/inputs_adjust.h"
|
||||
|
|
|
@ -232,7 +232,11 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
|
|||
std::vector<float> new_shape;
|
||||
MS_CHECK_TRUE_MSG(!shape_tensor->shape().empty(), RET_NULL_PTR, "out of range.");
|
||||
if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_4) {
|
||||
new_shape = {shape_data[kNumIndex_0], shape_data[kNumIndex_2], shape_data[kNumIndex_3], shape_data[kNumIndex_1]};
|
||||
if (shape_data[kNumIndex_0] != 1 || shape_data[kNumIndex_1] != 1) {
|
||||
MS_LOG(ERROR) << "Op resize don't support, which N dimension and C dimension is not 1.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
new_shape = {shape_data[kNumIndex_2], shape_data[kNumIndex_3]};
|
||||
} else if (shape_tensor->shape().at(0) == kNumGatherIndiceSize_2) {
|
||||
return RET_OK;
|
||||
} else {
|
||||
|
@ -240,7 +244,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
|
|||
}
|
||||
auto new_shape_node = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(new_shape_node != nullptr, RET_NULL_PTR, "new_shape_node is nullptr.");
|
||||
auto tensor_info = CreateTensorInfo(nullptr, 0, shape_tensor->shape(), shape_tensor->data_type());
|
||||
auto tensor_info = CreateTensorInfo(nullptr, 0, std::vector<int64_t>{kNumInputSize}, shape_tensor->data_type());
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "create tensor info failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -250,7 +254,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForConstShape(const FuncGraphPtr &fun
|
|||
MS_LOG(ERROR) << "data is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), tensor_info->Size());
|
||||
auto status = memcpy_s(new_shape_data, tensor_info->Size(), new_shape.data(), sizeof(float) * new_shape.size());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "init parameter from tensor info failed";
|
||||
return RET_ERROR;
|
||||
|
@ -268,7 +272,7 @@ STATUS UnifyFormatToNHWC::ConvertOnnxResizeForVariableShape(const FuncGraphPtr &
|
|||
auto gather_name = cnode->fullname_with_scope() + "_gather";
|
||||
auto gather_input = cnode->input(kNumResizeInputShape);
|
||||
auto abstract = cnode->input(kNumResizeInputShape)->abstract();
|
||||
std::vector<int> gather_indices = {0, 2, 3, 1}; // NCHW to NHWC
|
||||
std::vector<int> gather_indices = {kNumIndex_2, kNumIndex_3}; // fetch H and W dimension
|
||||
auto gather_cnode = opt::GenGatherNode(func_graph, gather_input, gather_indices, gather_name);
|
||||
if (gather_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "create gather cnode failed.";
|
||||
|
|
|
@ -102,6 +102,35 @@ static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
|
|||
|
||||
static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {};
|
||||
|
||||
static const std::unordered_map<std::string, std::vector<size_t>> ToNCHWOpMap = {
|
||||
{ops::kNameAdam, {10}},
|
||||
{ops::kNameApplyMomentum, {4}},
|
||||
{ops::kNameAvgPoolFusion, {1}},
|
||||
{ops::kNameAvgPoolGrad, {}},
|
||||
{ops::kNameBatchNorm, {1}},
|
||||
{ops::kNameBatchNormGrad, {1, 2}},
|
||||
{ops::kNameBatchToSpace, {1}},
|
||||
{ops::kNameBiasAdd, {1}},
|
||||
{ops::kNameBiasAddGrad, {1}},
|
||||
{ops::kNameConv2DBackpropInputFusion, {1}},
|
||||
{ops::kNameConv2DBackpropFilterFusion, {1, 2}},
|
||||
{ops::kNameConv2DFusion, {1}},
|
||||
{ops::kNameConv2dTransposeFusion, {1}},
|
||||
{ops::kNameDepthToSpace, {1}},
|
||||
{ops::kNameFusedBatchNorm, {1}},
|
||||
{ops::kNameInstanceNorm, {1}},
|
||||
{ops::kNameLRN, {1}},
|
||||
{ops::kNameMaxPoolFusion, {1}},
|
||||
{ops::kNameMaxPoolGrad, {}},
|
||||
{ops::kNamePReLUFusion, {1}},
|
||||
{ops::kNameResize, {1}},
|
||||
{ops::kNameResizeGrad, {}},
|
||||
{ops::kNameROIPooling, {1}},
|
||||
{ops::kNameSGD, {2}},
|
||||
{ops::kNameSpaceToBatch, {1}},
|
||||
{ops::kNameSpaceToBatchND, {1}},
|
||||
{ops::kNameSpaceToDepth, {1}}};
|
||||
|
||||
// a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not.
|
||||
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
|
||||
{ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
|
||||
|
@ -113,6 +142,7 @@ static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
|
|||
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap() { return ToNCHWOpMap; }
|
||||
bool IsDynamicFormatOp(const std::string &op_type) {
|
||||
return DynamicFormatOpList.find(op_type) != DynamicFormatOpList.end();
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ struct TransTypePair {
|
|||
};
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap();
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap();
|
||||
const std::unordered_map<std::string, std::vector<size_t>> &GetToNCHWOpMap();
|
||||
const std::vector<std::string> &GetDynamicFormatOpList();
|
||||
bool IsDynamicFormatOp(const std::string &op_type);
|
||||
bool IsDynamicFormatOpWithAxis(const std::string &op_type);
|
||||
|
|
|
@ -26,11 +26,6 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/graph/infershape_pass.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
#include "ops/fusion/conv2d_transpose_fusion.h"
|
||||
#include "ops/adam.h"
|
||||
#include "ops/sgd.h"
|
||||
#include "ops/apply_momentum.h"
|
||||
|
||||
using mindspore::converter::FmkType;
|
||||
namespace mindspore {
|
||||
|
@ -66,7 +61,7 @@ class ToFormatBase : public Pass {
|
|||
|
||||
protected:
|
||||
virtual STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
|
||||
virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
|
||||
virtual void SetSensitiveOps() { sensitive_ops_ = GetToNCHWOpMap(); }
|
||||
virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
|
||||
const ShapeVector &shape);
|
||||
virtual bool DecideWhetherInferShapeForNewNode() { return true; }
|
||||
|
|
|
@ -168,7 +168,8 @@ int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
|||
if (status != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (data_info.shape_.empty() ||
|
||||
if (data_info.shape_.empty() || (data_info.shape_.size() == 1 && data_info.shape_[0] == 0) ||
|
||||
std::all_of(data_info.shape_.begin(), data_info.shape_.end(), [](int val) { return val == 1; }) ||
|
||||
(data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -273,42 +274,6 @@ STATUS DecreaseTransposeAlgo::GenNewInput(const FuncGraphPtr &func_graph, const
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto prim_node = cnode->input(0);
|
||||
MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr");
|
||||
auto prim = GetValueNode<PrimitivePtr>(prim_node);
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
|
||||
auto &specify_nhwc_op_map = GetNHWCOpMap();
|
||||
auto &specify_nchw_op_map = GetNCHWOpMap();
|
||||
if (specify_nhwc_op_map.find(prim->name()) == specify_nhwc_op_map.end() &&
|
||||
specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()) {
|
||||
MS_LOG(ERROR) << "op don't meet nhwc condition.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<size_t> insert_index = specify_nchw_op_map.find(prim->name()) == specify_nchw_op_map.end()
|
||||
? specify_nhwc_op_map.at(prim->name())
|
||||
: specify_nchw_op_map.at(prim->name());
|
||||
if (insert_index.empty()) {
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimResizeGrad) && prim->GetAttr(ops::kMethod) != nullptr &&
|
||||
GetValue<int64_t>(prim->GetAttr(ops::kMethod)) == static_cast<int64_t>(mindspore::ResizeMethod::NEAREST)) {
|
||||
insert_index.push_back(1);
|
||||
} else {
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
insert_index.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &index : insert_index) {
|
||||
if (GenNewInput(func_graph, cnode, perm, true, index) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
TransTypePair *trans_insert_info) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
|
@ -329,27 +294,10 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
|
|||
MS_LOG(ERROR) << "change op attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto before_perm = trans_insert_info->pre_ == kNHWC2NCHW ? kNH2NC : kNC2NH;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (IsMonadNode(cnode->input(i))) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
|
||||
CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
|
||||
auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
|
||||
MS_ASSERT(input_make_tuple != nullptr);
|
||||
for (size_t j = 1; j < input_make_tuple->size(); ++j) {
|
||||
if (GenNewInput(func_graph, input_make_tuple, before_perm, true, j) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (GenNewInput(func_graph, cnode, before_perm, true, i) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
status = DoPreInsert(func_graph, cnode, trans_insert_info->pre_);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Do pre insert failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
status = ModifyCNodeFormat(cnode, trans_insert_info->pre_);
|
||||
if (status != lite::RET_OK) {
|
||||
|
@ -364,6 +312,63 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph,
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecreaseTransposeAlgo::DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
FormatTransNodeType trans_type) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto abstract = cnode->abstract();
|
||||
MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
|
||||
if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
|
||||
auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
auto abstract_list = abstract_tuple->elements();
|
||||
MS_CHECK_TRUE_RET(!abstract_list.empty(), lite::RET_OUT_OF_TENSOR_RANGE);
|
||||
abstract = abstract_list.front();
|
||||
MS_CHECK_TRUE_RET(abstract != nullptr, lite::RET_NULL_PTR);
|
||||
}
|
||||
ShapeVector shape;
|
||||
if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch shape from abstract fauled.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto HandleFunc = [this, &shape](const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index,
|
||||
FormatTransNodeType trans_type) -> STATUS {
|
||||
auto before_perm = trans_type == kNHWC2NCHW ? kNH2NC : kNC2NH;
|
||||
if (shape.size() == kInputSizeFour && !cnode->input(index)->isa<CNode>()) {
|
||||
if (ConvertTensorToNCOrNH(func_graph, cnode, index, fmk_type_, train_flag_, trans_type) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else if (GenNewInput(func_graph, cnode, before_perm, true, index) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "generate a new input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
};
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
MS_CHECK_TRUE_RET(cnode->input(i) != nullptr, lite::RET_NULL_PTR);
|
||||
if (IsMonadNode(cnode->input(i))) {
|
||||
continue;
|
||||
}
|
||||
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple) ||
|
||||
CheckPrimitiveType(cnode->input(i), kPrimMakeTupleV2)) {
|
||||
auto input_make_tuple = cnode->input(i)->cast<CNodePtr>();
|
||||
MS_ASSERT(input_make_tuple != nullptr);
|
||||
for (size_t j = 1; j < input_make_tuple->size(); ++j) {
|
||||
MS_CHECK_TRUE_RET(input_make_tuple->input(j) != nullptr, lite::RET_NULL_PTR);
|
||||
if (HandleFunc(func_graph, input_make_tuple, j, trans_type) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "handle pre insert failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (HandleFunc(func_graph, cnode, i, trans_type) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "handle pre insert failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<int> &perm) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
|
|
|
@ -41,7 +41,6 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
|
||||
private:
|
||||
STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm);
|
||||
STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before,
|
||||
size_t index = 0);
|
||||
bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph);
|
||||
|
@ -51,6 +50,7 @@ class DecreaseTransposeAlgo : public Pass {
|
|||
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
std::set<CNodePtr> *visit_transposes);
|
||||
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);
|
||||
STATUS DoPreInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type);
|
||||
int SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
int ResetSubGraphInput();
|
||||
int SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph);
|
||||
|
|
Loading…
Reference in New Issue