!5065 fix depthwise conv tflite parser bug

Merge pull request !5065 from lyvette/parser
This commit is contained in:
mindspore-ci-bot 2020-08-25 15:20:52 +08:00 committed by Gitee
commit 3a16925fa2
1 changed files with 40 additions and 35 deletions

View File

@ -249,47 +249,52 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
return RET_NULL_PTR;
}
auto data_shape = data_tensor->dims;
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
// update attr
conv_attr->group = 0;
conv_attr->format = attr->format;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;
conv_attr->strideH = attr->strideH;
conv_attr->strideW = attr->strideW;
conv_attr->padMode = attr->padMode;
conv_attr->padUp = attr->padUp;
conv_attr->padDown = attr->padDown;
conv_attr->padLeft = attr->padLeft;
conv_attr->padRight = attr->padRight;
conv_attr->dilateH = attr->dilateH;
conv_attr->dilateW = attr->dilateW;
conv_attr->hasBias = attr->hasBias;
conv_attr->activationType = attr->activationType;
if (data_shape[3] == 1) {
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = conv_attr.release();
// update attr
conv_attr->group = 1;
conv_attr->format = attr->format;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;
conv_attr->strideH = attr->strideH;
conv_attr->strideW = attr->strideW;
conv_attr->padMode = attr->padMode;
conv_attr->padUp = attr->padUp;
conv_attr->padDown = attr->padDown;
conv_attr->padLeft = attr->padLeft;
conv_attr->padRight = attr->padRight;
conv_attr->dilateH = attr->dilateH;
conv_attr->dilateW = attr->dilateW;
conv_attr->hasBias = attr->hasBias;
conv_attr->activationType = attr->activationType;
// update weight
auto weight_id = op->inputIndex[1];
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = conv_attr.release();
// update weight
auto weight_id = op->inputIndex[1];
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
return RET_ERROR;
}
weight_tensor->format = schema::Format_CHWK;
}
if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
}
weight_tensor->format = schema::Format_CHWK;
}
}
}