!20688 [LITE] fix bug

Merge pull request !20688 from yefeng/134-fix_bug
i-robot 2021-07-23 01:31:24 +00:00 committed by Gitee
commit d30a0d295c
10 changed files with 30 additions and 22 deletions

View File

@@ -27,6 +27,7 @@
namespace mindspore {
namespace lite {
namespace {
+constexpr int kNumWeightIndex = 2;
constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
static const std::unordered_map<int, int> TypeToTypeMap = {
{kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
@@ -298,7 +299,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&
-(index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
+(index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr)) {
data_info->format_ = mindspore::KHWC;
}
@@ -326,7 +327,7 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
MS_ASSERT(prim != nullptr);
if (value->isa<tensor::Tensor>()) {
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
-if (index == 2 && prim->GetAttr(ops::kFormat) != nullptr) {
+if (index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
}
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
@@ -366,6 +367,9 @@ int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fm
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
return RET_ERROR;
}
+if (perm.size() < 4) {
+  return RET_OK;
+}
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
data_info->format_ = NCHW;
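
Why the added guard matters: the comparison that follows reads perm[0] through perm[3], so a transpose permutation with fewer than four entries would otherwise index past the end of the vector. A minimal standalone sketch of the guarded check (IsNchwPerm is a hypothetical name, not part of the MindSpore sources):

#include <vector>

// Returns true only for the NHWC -> NCHW permutation {0, 3, 1, 2}.
// Checking size() first keeps perm[0..3] in bounds for shorter perms.
bool IsNchwPerm(const std::vector<int> &perm) {
  if (perm.size() < 4) {
    return false;
  }
  return perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2;
}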

View File

@@ -39,6 +39,7 @@ constexpr size_t kInitialSize = 1024;
constexpr int kMainGraphIndex = 0;
constexpr int kCallInputMinSize = 1;
constexpr int kSwitchInputMinSize = 3;
+constexpr int kNumDim2 = 2;
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
if (input_tensors == nullptr) {
@@ -82,9 +83,9 @@ void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, st
return;
}
for (int j = 0; j < data[1]; ++j) {
-element_shape.push_back(data[j + 2]);
+element_shape.push_back(data[j + kNumDim2]);
}
-tensor_shape = {data[data[1] + 2]};
+tensor_shape = {data[data[1] + kNumDim2]};
}
lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
if (lite_tensor == nullptr) {
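
The indexing above implies a serialized TensorList layout: data[1] holds the rank of the element shape, data[2] through data[1 + rank] hold the shape dimensions, and the entry right after them holds the tensor count (data[0] is presumably the element data type, which would match kTensorListMinSize = 3 * sizeof(int32_t) from the first file for the rank-0 case). A hedged sketch of that parse, with hypothetical names:

#include <cstdint>
#include <vector>

// Assumed layout: [elem_dtype, rank, dim_0 .. dim_{rank-1}, tensor_count].
void ParseTensorListBuffer(const int32_t *data, std::vector<int32_t> *element_shape,
                           std::vector<int32_t> *tensor_shape) {
  constexpr int kNumDim2 = 2;  // first shape dimension sits at offset 2
  for (int32_t j = 0; j < data[1]; ++j) {
    element_shape->push_back(data[j + kNumDim2]);
  }
  // One entry past the element shape: the number of tensors in the list.
  *tensor_shape = {data[data[1] + kNumDim2]};
}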

View File

@@ -138,8 +138,7 @@ STATUS CaffeModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << "weight node must param value";
return RET_OK;
}
-lite::STATUS status;
-status = HardCodeCaffe(conv_cnode, tensor_info, graph);
+auto status = HardCodeCaffe(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;

View File

@@ -29,6 +29,7 @@ namespace mindspore::lite {
namespace {
constexpr size_t kTripleNum = 3;
constexpr size_t kConvWeightIndex = 2;
+constexpr int64_t kNumDim2 = 2;
} // namespace
CNodePtr Conv1DInOutAdjust::NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
const std::vector<int64_t> &axis) {
@@ -75,7 +76,7 @@ lite::STATUS Conv1DInOutAdjust::ExpandFilterShape(const AnfNodePtr &weight_node,
switch (format) {
case schema::Format_NCHW:
case schema::Format_KCHW:
-new_shape.insert(new_shape.begin() + 2, 1);
+new_shape.insert(new_shape.begin() + kNumDim2, 1);
break;
case schema::Format_NHWC:
case schema::Format_KHWC:
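
The insertion point follows from the weight layout: a Conv1D filter in NCW/KCW form becomes NCHW/KCHW by inserting a height of 1 at dimension 2 ([K, C, W] -> [K, C, 1, W]), while the NHWC/KHWC branch (truncated above) would insert it at dimension 1. A minimal sketch of the NCHW branch under that assumption:

#include <cstdint>
#include <vector>

// Expand a 3-D conv1d filter shape to 4-D by inserting H = 1 at dim 2:
// [K, C, W] -> [K, C, 1, W] (NCHW/KCHW layouts).
std::vector<int64_t> ExpandNchwFilterShape(std::vector<int64_t> shape) {
  constexpr int64_t kNumDim2 = 2;
  shape.insert(shape.begin() + kNumDim2, 1);
  return shape;
}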

View File

@@ -122,8 +122,7 @@ STATUS OnnxModelParser::WeightFormatTransform(const std::set<FuncGraphPtr> &all_
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
-lite::STATUS status;
-status = HardCodeONNX(conv_cnode, tensor_info, graph);
+auto status = HardCodeONNX(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;

View File

@@ -28,13 +28,15 @@
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite {
+namespace {
+constexpr size_t kNumWeightIndex = 2;
+}
void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
all_func_graphs->insert(func_graph);
} else {
return;
}
auto nodes = func_graph->nodes();
for (auto &node : nodes) {
if (IsValueNode<FuncGraph>(node)) {
@@ -221,7 +223,7 @@ int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &con
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_const_post");
-conv_node->set_input(2, transpose_node);
+conv_node->set_input(kNumWeightIndex, transpose_node);
return lite::RET_OK;
}
@@ -244,5 +246,4 @@ int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, cons
}
return status;
}
} // namespace mindspore::lite

View File

@@ -608,8 +608,7 @@ STATUS TFModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
-lite::STATUS status;
-status = HardCodeTF(conv_cnode, tensor_info, graph);
+auto status = HardCodeTF(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
@@ -647,7 +646,6 @@ STATUS TFModelParser::HardCodeTF(const CNodePtr &conv_node, const tensor::Tensor
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_HWKC;
}
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_HWCK;

View File

@@ -134,8 +134,7 @@ STATUS TfliteModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
-lite::STATUS status;
-status = HardCodeTflite(conv_cnode, tensor_info, graph);
+auto status = HardCodeTflite(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;

View File

@@ -28,6 +28,8 @@ constexpr size_t kConvWeightIndex = 2;
constexpr size_t kConvBiasIndex = 3;
constexpr size_t kConvNoBiasLen = 3;
constexpr size_t kConvWithBiasLen = 4;
+constexpr size_t kNumDim1 = 1;
+constexpr size_t kNumDim2 = 2;
int GetOutChannels(const CNodePtr &conv_node) {
MS_ASSERT(conv_node != nullptr);
auto value_node = conv_node->input(0);
@@ -72,7 +74,7 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
MS_ASSERT(group > 0);
auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
auto cin_group = weight_tensor->shape()[0] / group;
-int area_size = weight_tensor->shape()[1] * weight_tensor->shape()[2];
+int area_size = weight_tensor->shape()[kNumDim1] * weight_tensor->shape()[kNumDim2];
for (int k = 0; k < cin_group; ++k) {
for (int j = 0; j < area_size; j++) {
for (int i = 0; i < kernel_num; ++i) {
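
Here area_size is the filter's spatial area, taken from dimensions 1 and 2 of the weight shape (dimension 0 is the grouped input-channel axis, per the cin_group line above). A small sketch of that computation, assuming the same 4-D layout:

#include <cstdint>
#include <vector>

// Spatial kernel area for a 4-D weight whose H and W sit in dims 1 and 2.
int64_t KernelArea(const std::vector<int64_t> &weight_shape) {
  constexpr size_t kNumDim1 = 1;  // kernel height
  constexpr size_t kNumDim2 = 2;  // kernel width
  return weight_shape[kNumDim1] * weight_shape[kNumDim2];
}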

View File

@@ -21,8 +21,12 @@
namespace mindspore::opt {
namespace {
+constexpr size_t kNumDim0 = 0;
+constexpr size_t kNumDim1 = 1;
+constexpr size_t kNumDim2 = 2;
+constexpr size_t kNumDim3 = 3;
constexpr int kAnfPopulaterInputNumTwo = 2;
-}
+}  // namespace
lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
@@ -55,10 +59,10 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
auto default_param = weight_param->default_param();
auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param);
auto weight_shape = weight_tensor->shape();
-std::vector<int64_t> kernel_size = {weight_shape[1], weight_shape[2]};
+std::vector<int64_t> kernel_size = {weight_shape[kNumDim1], weight_shape[kNumDim2]};
conv->set_kernel_size(kernel_size);
-conv->set_in_channel(weight_shape[3]);
-conv->set_out_channel(weight_shape[0]);
+conv->set_in_channel(weight_shape[kNumDim3]);
+conv->set_out_channel(weight_shape[kNumDim0]);
return lite::RET_OK;
}
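
The index constants encode the KHWC weight layout assumed by this pass: dimension 0 is the output-channel count (K), dimensions 1 and 2 are the kernel height and width, and dimension 3 is the input-channel count (C). A hedged sketch of that mapping (struct and function names are illustrative, not MindSpore APIs):

#include <cstdint>
#include <vector>

struct Conv2DShapeInfo {
  std::vector<int64_t> kernel_size;  // {H, W}
  int64_t in_channel;                // C
  int64_t out_channel;               // K
};

// Interpret a KHWC weight shape: [K, H, W, C].
Conv2DShapeInfo FromKhwcWeightShape(const std::vector<int64_t> &s) {
  return {{s[1], s[2]}, s[3], s[0]};
}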