forked from mindspore-Ecosystem/mindspore
fix depthwise conv tflite parser bug
This commit is contained in:
parent
86beb6e94b
commit
c87b43b578
|
@ -249,11 +249,13 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_shape = data_tensor->dims;
|
||||
|
||||
if (data_shape[3] == 1) {
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
||||
// update attr
|
||||
conv_attr->group = 0;
|
||||
conv_attr->group = 1;
|
||||
conv_attr->format = attr->format;
|
||||
conv_attr->kernelH = attr->kernelH;
|
||||
conv_attr->kernelW = attr->kernelW;
|
||||
|
@ -281,18 +283,21 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
|
|||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
|
||||
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
|
||||
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans filter format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
weight_tensor->format = schema::Format_CHWK;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue