!5065 fix depthwise conv tflite parser bug
Merge pull request !5065 from lyvette/parser
This commit is contained in:
commit
3a16925fa2
|
@ -249,47 +249,52 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_shape = data_tensor->dims;
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
||||
// update attr
|
||||
conv_attr->group = 0;
|
||||
conv_attr->format = attr->format;
|
||||
conv_attr->kernelH = attr->kernelH;
|
||||
conv_attr->kernelW = attr->kernelW;
|
||||
conv_attr->strideH = attr->strideH;
|
||||
conv_attr->strideW = attr->strideW;
|
||||
conv_attr->padMode = attr->padMode;
|
||||
conv_attr->padUp = attr->padUp;
|
||||
conv_attr->padDown = attr->padDown;
|
||||
conv_attr->padLeft = attr->padLeft;
|
||||
conv_attr->padRight = attr->padRight;
|
||||
conv_attr->dilateH = attr->dilateH;
|
||||
conv_attr->dilateW = attr->dilateW;
|
||||
conv_attr->hasBias = attr->hasBias;
|
||||
conv_attr->activationType = attr->activationType;
|
||||
if (data_shape[3] == 1) {
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = conv_attr.release();
|
||||
// update attr
|
||||
conv_attr->group = 1;
|
||||
conv_attr->format = attr->format;
|
||||
conv_attr->kernelH = attr->kernelH;
|
||||
conv_attr->kernelW = attr->kernelW;
|
||||
conv_attr->strideH = attr->strideH;
|
||||
conv_attr->strideW = attr->strideW;
|
||||
conv_attr->padMode = attr->padMode;
|
||||
conv_attr->padUp = attr->padUp;
|
||||
conv_attr->padDown = attr->padDown;
|
||||
conv_attr->padLeft = attr->padLeft;
|
||||
conv_attr->padRight = attr->padRight;
|
||||
conv_attr->dilateH = attr->dilateH;
|
||||
conv_attr->dilateW = attr->dilateW;
|
||||
conv_attr->hasBias = attr->hasBias;
|
||||
conv_attr->activationType = attr->activationType;
|
||||
|
||||
// update weight
|
||||
auto weight_id = op->inputIndex[1];
|
||||
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
|
||||
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = conv_attr.release();
|
||||
|
||||
// update weight
|
||||
auto weight_id = op->inputIndex[1];
|
||||
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
|
||||
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
|
||||
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans filter format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
weight_tensor->format = schema::Format_CHWK;
|
||||
}
|
||||
if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
|
||||
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans filter format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
weight_tensor->format = schema::Format_CHWK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue