From c87b43b5785e45b7f8291d02b40c3045ab35d8e7 Mon Sep 17 00:00:00 2001 From: lyvette Date: Tue, 25 Aug 2020 09:31:13 +0800 Subject: [PATCH] fix depthwise conv tflite parser bug --- .../parser/tflite/tflite_model_parser.cc | 75 ++++++++++--------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 824ef52768..77f74f1ccc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -249,47 +249,52 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) return RET_NULL_PTR; } auto data_shape = data_tensor->dims; - conv_attr->channelIn = data_shape[3]; - conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; - // update attr - conv_attr->group = 0; - conv_attr->format = attr->format; - conv_attr->kernelH = attr->kernelH; - conv_attr->kernelW = attr->kernelW; - conv_attr->strideH = attr->strideH; - conv_attr->strideW = attr->strideW; - conv_attr->padMode = attr->padMode; - conv_attr->padUp = attr->padUp; - conv_attr->padDown = attr->padDown; - conv_attr->padLeft = attr->padLeft; - conv_attr->padRight = attr->padRight; - conv_attr->dilateH = attr->dilateH; - conv_attr->dilateW = attr->dilateW; - conv_attr->hasBias = attr->hasBias; - conv_attr->activationType = attr->activationType; + if (data_shape[3] == 1) { + conv_attr->channelIn = data_shape[3]; + conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; - op->primitive->value.type = schema::PrimitiveType_Conv2D; - op->primitive->value.value = conv_attr.release(); + // update attr + conv_attr->group = 1; + conv_attr->format = attr->format; + conv_attr->kernelH = attr->kernelH; + conv_attr->kernelW = attr->kernelW; + conv_attr->strideH = attr->strideH; + conv_attr->strideW = attr->strideW; + conv_attr->padMode = attr->padMode; + conv_attr->padUp = attr->padUp; + conv_attr->padDown = attr->padDown; + conv_attr->padLeft = attr->padLeft; + conv_attr->padRight = attr->padRight; + conv_attr->dilateH = attr->dilateH; + conv_attr->dilateW = attr->dilateW; + conv_attr->hasBias = attr->hasBias; + conv_attr->activationType = attr->activationType; - // update weight - auto weight_id = op->inputIndex[1]; - auto &weight_tensor = sub_graph->allTensors.at(weight_id); - if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { - auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = conv_attr.release(); + + // update weight + auto weight_id = op->inputIndex[1]; + auto &weight_tensor = sub_graph->allTensors.at(weight_id); + if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans filter format failed."; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "The dataType of weight tensor is unsupported."; return RET_ERROR; } + weight_tensor->format = schema::Format_CHWK; } - if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { - auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans filter format failed."; - return RET_ERROR; - } - } - weight_tensor->format = schema::Format_CHWK; } } }