fix depthwise conv tflite parser bug

This commit is contained in:
lyvette 2020-08-25 09:31:13 +08:00
parent 86beb6e94b
commit c87b43b578
1 changed files with 40 additions and 35 deletions

View File

@ -249,11 +249,13 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
return RET_NULL_PTR;
}
auto data_shape = data_tensor->dims;
if (data_shape[3] == 1) {
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
// update attr
conv_attr->group = 0;
conv_attr->group = 1;
conv_attr->format = attr->format;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;
@ -281,18 +283,21 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
return RET_ERROR;
}
}
if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
return RET_ERROR;
}
weight_tensor->format = schema::Format_CHWK;
}
}
}
}
return RET_OK;
}