!4578 modify depthwise_conv tflite parser and conv-related model

Merge pull request !4578 from lyvette/tflite_parser
This commit is contained in:
mindspore-ci-bot 2020-08-18 15:47:13 +08:00 committed by Gitee
commit 5c5d7e3602
8 changed files with 32 additions and 19 deletions

View File

@ -28,13 +28,12 @@ TEST_F(TestTfliteParserConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);

View File

@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDeConv, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDeConv, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDeConv2D(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsDeConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 1);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);

View File

@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->group, 0);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
@ -64,13 +63,12 @@ TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
}
TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D();
ASSERT_EQ(val->format, schema::Format_NHWC);
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
ASSERT_EQ(val->hasBias, true);

View File

@ -221,12 +221,29 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
if (attr->channelMultiplier > 1) {
// update attr
std::unique_ptr<schema::Conv2DT> conv_attr(new schema::Conv2DT);
// get channel attr
if (op->inputIndex.empty()) {
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
return RET_NULL_PTR;
}
auto data_id = op->inputIndex[0];
if (sub_graph->allTensors.size() <= data_id) {
MS_LOG(ERROR) << "the number of allTensors is less than " << data_id;
return RET_ERROR;
}
auto &data_tensor = sub_graph->allTensors.at(data_id);
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return RET_NULL_PTR;
}
auto data_shape = data_tensor->dims;
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
// update attr
conv_attr->group = 0;
conv_attr->format = attr->format;
conv_attr->channelIn = attr->channelIn;
conv_attr->channelOut = attr->channelIn * attr->channelMultiplier;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;
conv_attr->strideH = attr->strideH;