forked from mindspore-Ecosystem/mindspore
!4578 modify depthwise_conv tflite parser and conv-related model
Merge pull request !4578 from lyvette/tflite_parser
This commit is contained in:
commit
5c5d7e3602
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -28,13 +28,12 @@ TEST_F(TestTfliteParserConv, OpType) {
|
|||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserConv, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 1);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
|
|
|
@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDeConv, OpType) {
|
|||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDeConv, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D();
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDeConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsDeConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 1);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
|
|
|
@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDepthwiseConv1, OpType) {
|
|||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D();
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->group, 0);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
|
@ -64,13 +63,12 @@ TEST_F(TestTfliteParserDepthwiseConv2, OpType) {
|
|||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D();
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D();
|
||||
ASSERT_EQ(val->format, schema::Format_NHWC);
|
||||
ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
ASSERT_EQ(val->hasBias, true);
|
||||
|
|
|
@ -221,12 +221,29 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
|
|||
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
auto attr = op->primitive->value.AsDepthwiseConv2D();
|
||||
if (attr->channelMultiplier > 1) {
|
||||
// update attr
|
||||
std::unique_ptr<schema::Conv2DT> conv_attr(new schema::Conv2DT);
|
||||
// get channel attr
|
||||
if (op->inputIndex.empty()) {
|
||||
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_id = op->inputIndex[0];
|
||||
if (sub_graph->allTensors.size() <= data_id) {
|
||||
MS_LOG(ERROR) << "the number of allTensors is less than " << data_id;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto &data_tensor = sub_graph->allTensors.at(data_id);
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the data tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_shape = data_tensor->dims;
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
||||
// update attr
|
||||
conv_attr->group = 0;
|
||||
conv_attr->format = attr->format;
|
||||
conv_attr->channelIn = attr->channelIn;
|
||||
conv_attr->channelOut = attr->channelIn * attr->channelMultiplier;
|
||||
conv_attr->kernelH = attr->kernelH;
|
||||
conv_attr->kernelW = attr->kernelW;
|
||||
conv_attr->strideH = attr->strideH;
|
||||
|
|
Loading…
Reference in New Issue