From 80ed2e1a12470aa70dd81f7349a4b1bed8ad4f98 Mon Sep 17 00:00:00 2001 From: lyvette Date: Mon, 17 Aug 2020 15:41:44 +0800 Subject: [PATCH] modify depthwise_conv tflite parser and conv-related model --- .../parser/tflite/test_data/conv.tflite | Bin 912 -> 836 bytes .../parser/tflite/test_data/deconv.tflite | Bin 976 -> 904 bytes .../tflite/test_data/depthwise_conv1.tflite | Bin 944 -> 852 bytes .../tflite/test_data/depthwise_conv2.tflite | Bin 936 -> 772 bytes .../parser/tflite/tflite_conv_parser_test.cc | 7 +++--- .../tflite/tflite_deconv_parser_test.cc | 7 +++--- .../tflite_depthwise_conv_parser_test.cc | 14 +++++------ .../parser/tflite/tflite_model_parser.cc | 23 +++++++++++++++--- 8 files changed, 32 insertions(+), 19 deletions(-) diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite index 2cb33152c33ca7603cd015dfc0d16f71c1fb5e18..78614e5be67f1aeae9899266a56219f4fe39e70f 100644 GIT binary patch literal 836 zcmd6lu}T9$6h&W5G)oLIXb_7ODJ(=Jn22C!p@>3+Aod}cg)Af+S);Y3WeRJ*!OyVr zEBpWpD^b^bW(>(EIB@sPym@oy-Zxt{v&)m#p4qBZEoXrht!x!#HFoOi7Ao`b)~7dsEmFb_6)!EK1jlMD>M< z*63pp@k^`<)DF8*4EMU`=+8?L+!`4;YXTb03eF=63R=N21)PAPt6kpL=e6?nnq;Ye z`bksxxEl@VXHKu+1+3{e_M69g7oFKoeZTNyc7OPvkVqSz=)t@vd`x?Tce>#{HSJkv za2d!gUa9qc)?{wH{&CY+_Th(%`Wautv5?~}D{j~S>AWT9IKx@~73${u&EB%m9(1q5 P{%Ft(6Z>-g&biuO%6MVN literal 912 zcmZ{iy-EX75QRsx#-LGyrm#q1A!s4NXb=kvK@o%%M67Ir=z$l$oml+S;IJ;wsJj1|WmFz!m=l@;eXI&a z!B>xOYU)Xd1W&k-pZPfYbQqAucQ3FD`f2ZnajKPWZBQy`>GY8|C7=uUI!The7M6Ni zi95upeJPGif%x86k}xklqBkwTBX1}J-^3YxVcqOTIYikn)5D_%JZfUm;ht*4W%DUl zmloqJA?;+#T8#B%ef9KT17BjF3&`G1Y;z&sU)v$k*{Cj1^qc!}o`2*1I_!@^bZ`a0zYH&VDS{$C{S= zw1e|%eNFgY@ja)_K^Hi4wim1Y$>za9zp{B=yjhO9`3Bru0DEqGy_)YN(35H6#le$` v5oR(nmfThIRrYToH~GlLnmBKaqbAORorf*hy^-Hn!*DcEk$jnV%A43<1wn4= literal 976 zcmZva&q^Cn6o*gkXo8J4S{EW&xG3nNq>Z|$;Gz&I6e5(;{Scgi3?viDjHQ&W#Fe=0 zqKo3fh0joY5}%;-A#BIrcV{jQp+i3I&H4K~XQpLlAK&(N%~q{tfz_>HOVe^Hi6k-n$Q*-Epe;JdWD3EKB{rS(asCh8JQ_ zpyQ2L5SP-oCw@y4zZ%U-|&O72xT^Bi;d znCH5CjsL6Lk^j);El6$PcI0zs_v7DBt@^I3_v8A}pT6wY_hZJ{IbJ5|$!Q*D$(J~j z+lK0WE68P@>Am=!pof}E;{C9H91q6BD4v99M=^f1t+`wI{h*DS;tc!dC&j= diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite index d820f874f942c1e29219657293cb3911f5041b92..d658f804ccb9e9ea102b86e27bb97197db88318a 100644 GIT binary patch literal 852 zcmd5)y-LGS7(LclV=c9)gF}W6A`TL)BDgsSf)GK(Q7FY+XrxI?8h=JdM-d;xCvo;U zoE+77&P}65U%(3|-@V_@$@#u}JpiW%!8Wi64-Q<^(Zn*nEA#-1j~$kNyTd(!ENo6B~p$#OYDj3q=hHX_r1NS+yYR z5z39bQ6g?}N&hSriJPk;KJiPVrdD>SUR5ie664X(!@{v2s98qU)ocWg@*h6)Tx?b= zt@LZJc_~lOd5#FpFza1MBg6eI>(=sOZudOXv-h8HL2c?wv(S9YzSMI$AN8gFzWdq& zo>-?B=CW)1&KMWZ@AW+9*L)Rk5#s*C*_N+r{l(oXqj>qcTh*#=Qsd<86`LoH``vTX O8}vhy;@!^Ov7BF0>1BNY literal 944 zcmaKq%PvDv6owZLE$R{)G3da+z(5+;#K1s81R)XWz(CT7o}?#fPo;_%FcL#U@f>1g z>>TRLxu71{C4Ip3U9BL*^JT-2ku+&qt4Q9UeQgw3<4 z5!=gqk_57iiuvMlaZbBmBRcHuZXZmoH)8p>!b(v$!>azfmtoz|!>m{m)UhV|L{0VO z1-)E}NkN~wa*KMQS8(?vm<9FJkE1x%N~e~Q3R+s-#8Bt7pl+s0r*uvK#QWSnzeeO1 zHngsjB+(XN-}y)qGEg)6GXg*A8fe4+eo+ZeTh+_xEY$f%ad@PfS?5b0ukW?cNcHjM zHBw(4Ca^uX_0j9p_i@CMi(0<)^;JD6?p}RhL-)78pI=6*e0P3+zD9QRi2IX&LGR~g z)Te3#Hyf`b_aBQ{l8ZY3%{iaud~Eyk4(QJbauW}}o*|w0YsV*HwNa~t&9cH>CHT@? d-bJO~2fd+9Z08I+-w@lN=*VerZhnEw^a;xbW;_4@ diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite index 9baa40effdd1924ae945bfcee0a4706f631adf3b..56190f5ee8e4d682abe011ef41fca7cb9ce999ce 100644 GIT binary patch delta 353 zcmZ3%-oiFP!%TvKfuRG4L9_u7GXk*>kO0wBfS3)4w=gjRhSb+EpP+AetG787CVu3X6yUSsV;J z4D1X7418dD2Byg|j5?+uB_Ip}Kx?7exxh+6TDZV63@jky9{v0OUjRt6Oy0<-pLhYt z%lrTTKgbatU{wwF3=EXxBA}TdTbL&sGU>1w0fS}oK_+>@Q&3@$BSF#%lRq*^PxfJs F0RTF9Ip6>Q literal 936 zcmZXS&q@MO6vj_Asg;$KgbNofTC@n!{y>WsK@bE65-tOg9B4qtP^%Yc)!JqBAgz3c zR;{83siyBY<1MR8ANO|e`Tm^a6wR!DT-`BSv7&jFv8?4RuiJud%w`1sD+grMCC31+xNoBY|0vr_&6=#4DLvS}u5LO~WDQhCptnN~d&9|4!HK=xapr zLL0C9D2lXs)SP`r5ej6+e^a0i-uxJF=7g4iHMqPD`o7LX@uv)QtMyVidkE=Pvd&t@huBz!goEAWl^ZLn-~6Nx83r44bkHS g)Y4NvW+mSRJ>d;+a&~`zy9K>ZXxh87c+x@h3+A0>5C8xG diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc index dc2a763d499..202079abc64 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserConv, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserConv, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 1); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc index 9ff4704c7c2..8badaa1d334 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDeConv, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDeConv, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDeConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsDeConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 1); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc index 6ba9c4d1e68..243bbcc867a 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc @@ -28,13 +28,12 @@ TEST_F(TestTfliteParserDepthwiseConv1, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->group, 0); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); @@ -64,13 +63,12 @@ TEST_F(TestTfliteParserDepthwiseConv2, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; - ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; } TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) { - ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr); - auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsDepthwiseConv2D(); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); ASSERT_EQ(val->hasBias, true); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index b9990d0ace7..6c25cc895b9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -222,12 +222,29 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) { if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { auto attr = op->primitive->value.AsDepthwiseConv2D(); if (attr->channelMultiplier > 1) { - // update attr std::unique_ptr conv_attr(new schema::Conv2DT); + // get channel attr + if (op->inputIndex.empty()) { + MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; + return RET_NULL_PTR; + } + auto data_id = op->inputIndex[0]; + if (sub_graph->allTensors.size() <= data_id) { + MS_LOG(ERROR) << "the number of allTensors is less than " << data_id; + return RET_ERROR; + } + auto &data_tensor = sub_graph->allTensors.at(data_id); + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the data tensor is null"; + return RET_NULL_PTR; + } + auto data_shape = data_tensor->dims; + conv_attr->channelIn = data_shape[3]; + conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; + + // update attr conv_attr->group = 0; conv_attr->format = attr->format; - conv_attr->channelIn = attr->channelIn; - conv_attr->channelOut = attr->channelIn * attr->channelMultiplier; conv_attr->kernelH = attr->kernelH; conv_attr->kernelW = attr->kernelW; conv_attr->strideH = attr->strideH;