!20688 [LITE] fix bug
Merge pull request !20688 from yefeng/134-fix_bug
This commit is contained in:
commit
d30a0d295c
|
@ -27,6 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr int kNumWeightIndex = 2;
|
||||
constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
|
||||
static const std::unordered_map<int, int> TypeToTypeMap = {
|
||||
{kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
|
||||
|
@ -298,7 +299,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&
|
||||
(index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
|
||||
(index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr)) {
|
||||
data_info->format_ = mindspore::KHWC;
|
||||
}
|
||||
|
||||
|
@ -326,7 +327,7 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
|
|||
MS_ASSERT(prim != nullptr);
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
|
||||
if (index == 2 && prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
if (index == kNumWeightIndex && prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
}
|
||||
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
|
||||
|
@ -366,6 +367,9 @@ int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fm
|
|||
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (perm.size() < 4) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
|
||||
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
|
||||
data_info->format_ = NCHW;
|
||||
|
|
|
@ -39,6 +39,7 @@ constexpr size_t kInitialSize = 1024;
|
|||
constexpr int kMainGraphIndex = 0;
|
||||
constexpr int kCallInputMinSize = 1;
|
||||
constexpr int kSwitchInputMinSize = 3;
|
||||
constexpr int kNumDim2 = 2;
|
||||
|
||||
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
|
||||
if (input_tensors == nullptr) {
|
||||
|
@ -82,9 +83,9 @@ void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, st
|
|||
return;
|
||||
}
|
||||
for (int j = 0; j < data[1]; ++j) {
|
||||
element_shape.push_back(data[j + 2]);
|
||||
element_shape.push_back(data[j + kNumDim2]);
|
||||
}
|
||||
tensor_shape = {data[data[1] + 2]};
|
||||
tensor_shape = {data[data[1] + kNumDim2]};
|
||||
}
|
||||
lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
|
||||
if (lite_tensor == nullptr) {
|
||||
|
|
|
@ -138,8 +138,7 @@ STATUS CaffeModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
|
|||
MS_LOG(ERROR) << "weight node must param value";
|
||||
return RET_OK;
|
||||
}
|
||||
lite::STATUS status;
|
||||
status = HardCodeCaffe(conv_cnode, tensor_info, graph);
|
||||
auto status = HardCodeCaffe(conv_cnode, tensor_info, graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace mindspore::lite {
|
|||
namespace {
|
||||
constexpr size_t kTripleNum = 3;
|
||||
constexpr size_t kConvWeightIndex = 2;
|
||||
constexpr int64_t kNumDim2 = 2;
|
||||
} // namespace
|
||||
CNodePtr Conv1DInOutAdjust::NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
|
||||
const std::vector<int64_t> &axis) {
|
||||
|
@ -75,7 +76,7 @@ lite::STATUS Conv1DInOutAdjust::ExpandFilterShape(const AnfNodePtr &weight_node,
|
|||
switch (format) {
|
||||
case schema::Format_NCHW:
|
||||
case schema::Format_KCHW:
|
||||
new_shape.insert(new_shape.begin() + 2, 1);
|
||||
new_shape.insert(new_shape.begin() + kNumDim2, 1);
|
||||
break;
|
||||
case schema::Format_NHWC:
|
||||
case schema::Format_KHWC:
|
||||
|
|
|
@ -122,8 +122,7 @@ STATUS OnnxModelParser::WeightFormatTransform(const std::set<FuncGraphPtr> &all_
|
|||
auto weight_node = conv_cnode->input(kConvWeightIndex);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto tensor_info = opt::GetTensorInfo(weight_node);
|
||||
lite::STATUS status;
|
||||
status = HardCodeONNX(conv_cnode, tensor_info, graph);
|
||||
auto status = HardCodeONNX(conv_cnode, tensor_info, graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -28,13 +28,15 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr size_t kNumWeightIndex = 2;
|
||||
}
|
||||
void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
|
||||
if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
|
||||
all_func_graphs->insert(func_graph);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
auto nodes = func_graph->nodes();
|
||||
for (auto &node : nodes) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
|
@ -221,7 +223,7 @@ int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &con
|
|||
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
|
||||
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
|
||||
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_const_post");
|
||||
conv_node->set_input(2, transpose_node);
|
||||
conv_node->set_input(kNumWeightIndex, transpose_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
|
@ -244,5 +246,4 @@ int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, cons
|
|||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -608,8 +608,7 @@ STATUS TFModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
|
|||
auto weight_node = conv_cnode->input(kConvWeightIndex);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto tensor_info = opt::GetTensorInfo(weight_node);
|
||||
lite::STATUS status;
|
||||
status = HardCodeTF(conv_cnode, tensor_info, graph);
|
||||
auto status = HardCodeTF(conv_cnode, tensor_info, graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
|
@ -647,7 +646,6 @@ STATUS TFModelParser::HardCodeTF(const CNodePtr &conv_node, const tensor::Tensor
|
|||
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
|
||||
weight_src_format = schema::Format::Format_HWKC;
|
||||
}
|
||||
|
||||
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
|
||||
weight_src_format = schema::Format::Format_HWCK;
|
||||
|
|
|
@ -134,8 +134,7 @@ STATUS TfliteModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
|
|||
auto weight_node = conv_cnode->input(kConvWeightIndex);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto tensor_info = opt::GetTensorInfo(weight_node);
|
||||
lite::STATUS status;
|
||||
status = HardCodeTflite(conv_cnode, tensor_info, graph);
|
||||
auto status = HardCodeTflite(conv_cnode, tensor_info, graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -28,6 +28,8 @@ constexpr size_t kConvWeightIndex = 2;
|
|||
constexpr size_t kConvBiasIndex = 3;
|
||||
constexpr size_t kConvNoBiasLen = 3;
|
||||
constexpr size_t kConvWithBiasLen = 4;
|
||||
constexpr size_t kNumDim1 = 1;
|
||||
constexpr size_t kNumDim2 = 2;
|
||||
int GetOutChannels(const CNodePtr &conv_node) {
|
||||
MS_ASSERT(conv_node != nullptr);
|
||||
auto value_node = conv_node->input(0);
|
||||
|
@ -72,7 +74,7 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
|
|||
MS_ASSERT(group > 0);
|
||||
auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
|
||||
auto cin_group = weight_tensor->shape()[0] / group;
|
||||
int area_size = weight_tensor->shape()[1] * weight_tensor->shape()[2];
|
||||
int area_size = weight_tensor->shape()[kNumDim2] * weight_tensor->shape()[kNumDim2];
|
||||
for (int k = 0; k < cin_group; ++k) {
|
||||
for (int j = 0; j < area_size; j++) {
|
||||
for (int i = 0; i < kernel_num; ++i) {
|
||||
|
|
|
@ -21,8 +21,12 @@
|
|||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
constexpr size_t kNumDim0 = 0;
|
||||
constexpr size_t kNumDim1 = 1;
|
||||
constexpr size_t kNumDim2 = 2;
|
||||
constexpr size_t kNumDim3 = 3;
|
||||
constexpr int kAnfPopulaterInputNumTwo = 2;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
|
@ -55,10 +59,10 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
|
|||
auto default_param = weight_param->default_param();
|
||||
auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param);
|
||||
auto weight_shape = weight_tensor->shape();
|
||||
std::vector<int64_t> kernel_size = {weight_shape[1], weight_shape[2]};
|
||||
std::vector<int64_t> kernel_size = {weight_shape[kNumDim1], weight_shape[kNumDim2]};
|
||||
conv->set_kernel_size(kernel_size);
|
||||
conv->set_in_channel(weight_shape[3]);
|
||||
conv->set_out_channel(weight_shape[0]);
|
||||
conv->set_in_channel(weight_shape[kNumDim3]);
|
||||
conv->set_out_channel(weight_shape[kNumDim0]);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue