From 123c2024a547ed4a32c112b5104709ba7120a1bf Mon Sep 17 00:00:00 2001 From: lyvette Date: Wed, 12 Aug 2020 17:26:51 +0800 Subject: [PATCH] refactor tflite parsers. append ut refactor tflite parsers modify tflite parser, ut and model supplement caffe flatten parser fix the weight tensor format of deconv bug fix bug when idx=-1 fix the weight tensor format of depthConv bug. --- mindspore/lite/test/run_test.sh | 28 +- .../test_data/{add1.tflite => add.tflite} | Bin .../parser/tflite/test_data/add2.tflite | Bin 560 -> 0 bytes .../parser/tflite/test_data/add3.tflite | Bin 564 -> 0 bytes .../parser/tflite/test_data/argmax.tflite | Bin 0 -> 604 bytes .../parser/tflite/test_data/concat.tflite | Bin 0 -> 580 bytes .../parser/tflite/test_data/conv.tflite | Bin 0 -> 912 bytes .../parser/tflite/test_data/deconv.tflite | Bin 0 -> 976 bytes .../tflite/test_data/depthwise_conv1.tflite | Bin 0 -> 944 bytes .../tflite/test_data/depthwise_conv2.tflite | Bin 0 -> 936 bytes .../test_data/{div1.tflite => div.tflite} | Bin .../parser/tflite/test_data/div2.tflite | Bin 572 -> 0 bytes .../parser/tflite/test_data/div3.tflite | Bin 572 -> 0 bytes .../parser/tflite/test_data/hardswish.tflite | Bin 0 -> 620 bytes .../{sigmoid.tflite => logistic.tflite} | Bin ...avg_pooling.tflite => mean_pooling.tflite} | Bin .../test_data/{mul1.tflite => mul.tflite} | Bin .../parser/tflite/test_data/mul2.tflite | Bin 564 -> 0 bytes .../parser/tflite/test_data/mul3.tflite | Bin 568 -> 0 bytes .../parser/tflite/test_data/realdiv.tflite | Bin 0 -> 588 bytes .../parser/tflite/test_data/slice.tflite | Bin 0 -> 676 bytes .../parser/tflite/test_data/stack.tflite | Bin 0 -> 596 bytes .../test_data/{sub1.tflite => sub.tflite} | Bin .../parser/tflite/test_data/sub2.tflite | Bin 564 -> 0 bytes .../parser/tflite/test_data/sub3.tflite | Bin 568 -> 0 bytes .../parser/tflite/test_data/transpose.tflite | Bin 0 -> 584 bytes .../tflite/tflite_activation_parser_test.cc | 71 ++- .../parser/tflite/tflite_addn_parser_test.cc | 6 +- .../tflite/tflite_argmax_parser_test.cc | 44 ++ .../tflite/tflite_argmin_parser_test.cc | 5 +- .../tflite/tflite_arithmetic_parser_test.cc | 241 ++-------- .../tflite_batch_to_space_nd_parser_test.cc | 12 +- .../parser/tflite/tflite_cast_parser_test.cc | 9 +- .../tflite/tflite_concat_parser_test.cc | 41 ++ .../parser/tflite/tflite_conv_parser_test.cc | 57 +++ .../tflite/tflite_deconv_parser_test.cc | 58 +++ .../tflite_depth_to_space_parser_test.cc | 8 +- .../tflite_depthwise_conv_parser_test.cc | 92 ++++ .../parser/tflite/tflite_fill_parser_test.cc | 8 +- .../tflite/tflite_gather_nd_parser_test.cc | 5 +- .../tflite/tflite_gather_parser_test.cc | 5 +- .../parser/tflite/tflite_lrn_parser_test.cc | 5 +- .../tflite/tflite_one_hot_parser_test.cc | 7 +- .../parser/tflite/tflite_pad_parser_test.cc | 6 +- .../tflite/tflite_pooling_parser_test.cc | 12 +- .../tflite/tflite_reduce_parser_test.cc | 40 +- .../tflite/tflite_reshape_parser_test.cc | 7 +- .../tflite/tflite_resize_parser_test.cc | 12 +- .../tflite/tflite_reverse_parser_test.cc | 6 +- .../tflite_reverse_sequence_parser_test.cc | 12 +- .../parser/tflite/tflite_slice_parser_test.cc | 45 ++ .../tflite/tflite_softmax_parser_test.cc | 8 +- .../tflite_space_to_batch_nd_parser_test.cc | 12 +- .../tflite_space_to_depth_parser_test.cc | 8 +- .../tflite_sparse_to_dense_parser_test.cc | 18 +- .../parser/tflite/tflite_split_parser_test.cc | 12 +- .../tflite/tflite_split_v_parser_test.cc | 12 +- .../parser/tflite/tflite_stack_parser_test.cc | 44 ++ .../tflite_strided_slice_parser_test.cc | 28 +- .../parser/tflite/tflite_tile_parser_test.cc | 8 +- .../tflite/tflite_topk_v2_parser_test.cc | 11 +- .../tflite/tflite_transpose_parser_test.cc | 43 ++ .../tflite/tflite_unique_parser_test.cc | 7 +- .../tflite/tflite_unstack_parser_test.cc | 8 +- .../node/weight_format_pass.cc | 4 +- .../parser/caffe/caffe_flatten_parser.cc | 11 +- .../parser/tflite/tflite_activation_parser.cc | 98 +++-- .../parser/tflite/tflite_activation_parser.h | 46 +- .../parser/tflite/tflite_addn_parser.cc | 28 +- .../parser/tflite/tflite_addn_parser.h | 14 +- .../parser/tflite/tflite_argmax_parser.cc | 28 +- .../parser/tflite/tflite_argmax_parser.h | 20 +- .../parser/tflite/tflite_argmin_parser.cc | 29 +- .../parser/tflite/tflite_argmin_parser.h | 20 +- .../parser/tflite/tflite_arithmetic_parser.cc | 220 ++++------ .../parser/tflite/tflite_arithmetic_parser.h | 43 +- .../tflite/tflite_batch_to_space_parser.cc | 27 +- .../tflite/tflite_batch_to_space_parser.h | 13 +- .../tflite/tflite_broadcast_to_parser.cc | 23 +- .../tflite/tflite_broadcast_to_parser.h | 14 +- .../parser/tflite/tflite_cast_parser.cc | 31 +- .../parser/tflite/tflite_cast_parser.h | 14 +- .../parser/tflite/tflite_concat_parser.cc | 29 +- .../parser/tflite/tflite_concat_parser.h | 20 +- .../parser/tflite/tflite_conv_parser.cc | 90 ++-- .../parser/tflite/tflite_conv_parser.h | 20 +- .../parser/tflite/tflite_converter.h | 1 + .../parser/tflite/tflite_deconv_parser.cc | 64 ++- .../parser/tflite/tflite_deconv_parser.h | 14 +- .../tflite/tflite_depth_to_space_parser.cc | 24 +- .../tflite/tflite_depth_to_space_parser.h | 14 +- .../tflite/tflite_depthwise_conv_parser.cc | 128 ++---- .../tflite/tflite_depthwise_conv_parser.h | 25 +- .../parser/tflite/tflite_dequantize_parser.cc | 41 +- .../parser/tflite/tflite_dequantize_parser.h | 19 +- .../tflite/tflite_expand_dims_parser.cc | 17 +- .../parser/tflite/tflite_expand_dims_parser.h | 20 +- .../parser/tflite/tflite_fakequant_parser.cc | 75 ---- .../parser/tflite/tflite_fakequant_parser.h | 39 -- .../parser/tflite/tflite_fill_parser.cc | 31 +- .../parser/tflite/tflite_fill_parser.h | 20 +- .../tflite/tflite_fullyconnected_parser.cc | 66 ++- .../tflite/tflite_fullyconnected_parser.h | 24 +- .../parser/tflite/tflite_gather_nd_parser.cc | 50 +-- .../parser/tflite/tflite_gather_nd_parser.h | 20 +- .../parser/tflite/tflite_gather_parser.cc | 48 +- .../parser/tflite/tflite_gather_parser.h | 20 +- .../parser/tflite/tflite_hard_swish_parser.cc | 50 --- .../parser/tflite/tflite_hard_swish_parser.h | 41 -- .../parser/tflite/tflite_logical_parser.cc | 32 +- .../parser/tflite/tflite_logical_parser.h | 20 +- .../parser/tflite/tflite_lrn_parser.cc | 24 +- .../parser/tflite/tflite_lrn_parser.h | 20 +- .../parser/tflite/tflite_model_parser.cc | 412 +++++++++--------- .../parser/tflite/tflite_model_parser.h | 52 ++- .../parser/tflite/tflite_node_parser.cc | 74 ---- .../parser/tflite/tflite_node_parser.h | 72 ++- .../parser/tflite/tflite_one_hot_parser.cc | 30 +- .../parser/tflite/tflite_one_hot_parser.h | 20 +- .../parser/tflite/tflite_pad_parser.cc | 28 +- .../parser/tflite/tflite_pad_parser.h | 19 +- .../parser/tflite/tflite_pooling_parser.cc | 62 ++- .../parser/tflite/tflite_pooling_parser.h | 20 +- .../parser/tflite/tflite_range_parser.cc | 26 +- .../parser/tflite/tflite_range_parser.h | 20 +- .../parser/tflite/tflite_rank_parser.cc | 23 +- .../parser/tflite/tflite_rank_parser.h | 20 +- .../parser/tflite/tflite_reduce_parser.cc | 34 +- .../parser/tflite/tflite_reduce_parser.h | 20 +- .../parser/tflite/tflite_reshape_parser.cc | 39 +- .../parser/tflite/tflite_reshape_parser.h | 19 +- .../parser/tflite/tflite_resize_parser.cc | 31 +- .../parser/tflite/tflite_resize_parser.h | 19 +- .../parser/tflite/tflite_reverse_parser.cc | 26 +- .../parser/tflite/tflite_reverse_parser.h | 20 +- .../tflite/tflite_reverse_sequence_parser.cc | 18 +- .../tflite/tflite_reverse_sequence_parser.h | 14 +- .../parser/tflite/tflite_scatter_nd_parser.cc | 48 +- .../parser/tflite/tflite_scatter_nd_parser.h | 20 +- .../parser/tflite/tflite_shape_parser.cc | 23 +- .../parser/tflite/tflite_shape_parser.h | 20 +- .../parser/tflite/tflite_slice_parser.cc | 29 +- .../parser/tflite/tflite_slice_parser.h | 19 +- .../parser/tflite/tflite_softmax_parser.cc | 29 +- .../parser/tflite/tflite_softmax_parser.h | 19 +- .../tflite/tflite_space_to_batch_nd_parser.cc | 16 +- .../tflite/tflite_space_to_batch_nd_parser.h | 14 +- .../tflite/tflite_space_to_depth_parser.cc | 17 +- .../tflite/tflite_space_to_depth_parser.h | 14 +- .../tflite/tflite_sparse_to_dense_parser.cc | 16 +- .../tflite/tflite_sparse_to_dense_parser.h | 14 +- .../parser/tflite/tflite_split_parser.cc | 36 +- .../parser/tflite/tflite_split_parser.h | 20 +- .../parser/tflite/tflite_split_v_parser.cc | 38 +- .../parser/tflite/tflite_split_v_parser.h | 20 +- .../parser/tflite/tflite_squeeze_parser.cc | 25 +- .../parser/tflite/tflite_squeeze_parser.h | 20 +- .../parser/tflite/tflite_stack_parser.cc | 32 +- .../parser/tflite/tflite_stack_parser.h | 19 +- .../tflite/tflite_strided_slice_parser.cc | 14 +- .../tflite/tflite_strided_slice_parser.h | 14 +- .../parser/tflite/tflite_tile_parser.cc | 16 +- .../parser/tflite/tflite_tile_parser.h | 14 +- .../parser/tflite/tflite_topk_v2_parser.cc | 17 +- .../parser/tflite/tflite_topk_v2_parser.h | 14 +- .../parser/tflite/tflite_transpose_parser.cc | 39 +- .../parser/tflite/tflite_transpose_parser.h | 19 +- .../parser/tflite/tflite_unique_parser.cc | 21 +- .../parser/tflite/tflite_unique_parser.h | 26 +- .../parser/tflite/tflite_unstack_parser.cc | 20 +- .../parser/tflite/tflite_unstack_parser.h | 14 +- .../converter/parser/tflite/tflite_util.cc | 87 +++- .../converter/parser/tflite/tflite_util.h | 13 +- .../parser/tflite/tflite_where_parser.cc | 18 +- .../parser/tflite/tflite_where_parser.h | 14 +- .../parser/tflite/tflite_zeros_like_parser.cc | 16 +- .../parser/tflite/tflite_zeros_like_parser.h | 14 +- 177 files changed, 2643 insertions(+), 2223 deletions(-) rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{add1.tflite => add.tflite} (100%) delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/conv.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/deconv.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{div1.tflite => div.tflite} (100%) delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/hardswish.tflite rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{sigmoid.tflite => logistic.tflite} (100%) rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{avg_pooling.tflite => mean_pooling.tflite} (100%) rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{mul1.tflite => mul.tflite} (100%) delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite rename mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/{sub1.tflite => sub.tflite} (100%) delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite delete mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc delete mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc delete mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h delete mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc delete mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h delete mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc diff --git a/mindspore/lite/test/run_test.sh b/mindspore/lite/test/run_test.sh index 9e9d677d9b3..f048dbb3641 100755 --- a/mindspore/lite/test/run_test.sh +++ b/mindspore/lite/test/run_test.sh @@ -12,19 +12,19 @@ cp -r ${CUR_DIR}/ut/tools/converter/parser/tflite/test_data/* ./ TEST_DATA_DIR=${CUR_DIR}/../../../tests/ut/data/dataset/ cp -fr $TEST_DATA_DIR/testPK ./data -#./lite-test --gtest_filter="*MindDataTestTensorDE*" -#./lite-test --gtest_filter="*MindDataTestEager*" -# -#./lite-test --gtest_filter="TestTfliteParser*" -# -#./lite-test --gtest_filter="*TestHebing*" -# -#./lite-test --gtest_filter=TestFcFp32* -#./lite-test --gtest_filter=TestConv1x1Fp32* -#./lite-test --gtest_filter=TestStrassenFp32* -#./lite-test --gtest_filter=TestDeConvolutionFp32* -# -#./lite-test --gtest_filter=TestPadInt8.* -#./lite-test --gtest_filter=TestDeconvInt8.* +./lite-test --gtest_filter="*MindDataTestTensorDE*" +./lite-test --gtest_filter="*MindDataTestEager*" + +./lite-test --gtest_filter="TestTfliteParser*" + +./lite-test --gtest_filter="*TestHebing*" + +./lite-test --gtest_filter=TestFcFp32* +./lite-test --gtest_filter=TestConv1x1Fp32* +./lite-test --gtest_filter=TestStrassenFp32* +./lite-test --gtest_filter=TestDeConvolutionFp32* + +./lite-test --gtest_filter=TestPadInt8.* +./lite-test --gtest_filter=TestDeconvInt8.* ./lite-test --gtest_filter="TestTfliteParser*" diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add1.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add2.tflite deleted file mode 100644 index c0c379a60ce1c0c4a1170569295fce1531352d7d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 560 zcmYjO!AiqG6dX0$QneQKkV6hV^w5JuiwGVof*@2V;yn}-FwkruM)3nYdGeF`4dT&* zU*SnyXLeKc!R+k5yf-^<69Krnm>dJUh!CKSE%dO@Y{(3-M)(u(?6SZCd4(8q5x)}a z#3v6Q#FTg^PMIT75l&a?ZMO$+0``a^H%awi>b%I3$E>=~OM|Z;YxTWHjN+5{h`fv) zZmuSm*ZuR-u)faHn3*o!{(0{*BhP(eKnw|a)+{o#Ak}`$)q7lRI3nxrVwM}}N=I+f zq8sTlHbX9&Np+#MRx_$!Q0NSF+^nkVG`SbnYNq9z>Mr%@?W83>OVeaTUa%I}phm#e ze^(7K9=|m@>Xi@qQM3rxfB59Oe7$DS&i`4>AP@4W*#^Xa58FIX=d;`FuAHY?<-C0J JzJMB1`32v2I`aSk diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/add3.tflite deleted file mode 100644 index d36b81e327a2d30f36168106db579e0b0eb4f12c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 564 zcmaKpy-LJD6os!E*JahXsKpj3Qdn3>@V8e4L0Cb=J}f3+V3R;xWgo)I%7^eNd!Y*0*h{11e@qpvlahUDa%lxB!mX z6kowI_$K)RpWp+WQX`;fPFLz(cQ!hiZGk*1;_9)OWO*7trPV`L7F0X%JobjyCmuv6 z(LS+V8yv2Nmp9>gS#WQXCJ_}M*MBqkROCGbyFk9Bp_yjj4NSoz(2`d{xncn1GaqG` z@|~96XN4D%HfwdjruoDR>$-05&ug0X+}vXI;ED4BS97)Rt6sXJca)amS(3zu#3|hO zU%oh+N%PF}*U7KPqr3FJ94BcN^P%W;oL1vqs?G1pyL!uqw06LMTJkWXzlWMsegNV+ BI>rD1 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/argmax.tflite new file mode 100644 index 0000000000000000000000000000000000000000..6330efc0b8b9b878fa0dce3ac0f29e323eea7945 GIT binary patch literal 604 zcmYL{!Ab&A7==$ZrI?tMgbNogTDS-^YGKO=f4!7tpL6e=d;Y6!NjDe6V`)ciMXIQ#hMHDeR-|>Kc#~d(Djk^Dj1xcLH)F%7 zm86Psi)bF0Xln+(xK~0i_tLI0PttMrlrEEbJbsR|$7GS}>nG1yZx~0t(_Y`as%{GJ?XN`AJdyc7Hib_>qI0WTwyx&%>$jJPj9a4Oje_xf#0RpMV$NaL(#Z@gyJn Ww~VXV!rCvG5H`R7 literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/concat.tflite new file mode 100644 index 0000000000000000000000000000000000000000..11d02991540dd2920662a75d808448bb0ff69885 GIT binary patch literal 580 zcmZ9JK}*9x5QU#LYN=X_7U>~}96WfCU=g8rK@h5lcn=|6FwkruMt^|^@#HTOe~t$a z9y|*1y=lej!rPgh?0)+u8v?jK9UlSP2oa!-4fN1wb-)U+OwCrngH?e8#wBXRN%5Il zp+33zpx&uBYQ!1?i7>9{ZOg%%fE}vHeNx@|Ixk#u@2cCp^tfo6M&EnXFg}hC8GptW zlk@S}RdiB%?yub>W~EEFzuvp7$a6&PQ}Qc@^+H3h%pa*1r6G@!)`}se`-?R9V(&87 ze_5j=u{k<>oCcJ716KJx;QW`h$mbCwh}5w&Xz8ve_NqSRO}uEHYb_6o;%24K(i*6_ z|7Ctge4SB$+9#=|xJg}3)6CtJ(}}B+A=h#!rh4jIHhXRLBoFc$l$oml+S;IJ;wsJj1|WmFz!m=l@;eXI&a z!B>xOYU)Xd1W&k-pZPfYbQqAucQ3FD`f2ZnajKPWZBQy`>GY8|C7=uUI!The7M6Ni zi95upeJPGif%x86k}xklqBkwTBX1}J-^3YxVcqOTIYikn)5D_%JZfUm;ht*4W%DUl zmloqJA?;+#T8#B%ef9KZ|$;Gz&I6e5(;{Scgi3?viDjHQ&W#Fe=0 zqKo3fh0joY5}%;-A#BIrcV{jQp+i3I&H4K~XQpLlAK&(N%~q{tfz_>HOVe^Hi6k-n$Q*-Epe;JdWD3EKB{rS(asCh8JQ_ zpyQ2L5SP-oCw@y4zZ%U-|&O72xT^Bi;d znCH5CjsL6Lk^j);El6$PcI0zs_v7DBt@^I3_v8A}pT6wY_hZJ{IbJ5|$!Q*D$(J~j z+lK0WE68P@>Am=!pof}E;{C9H91q6BD4v99M=^f1t+`wI{h*DS;tc!dC&j= literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv1.tflite new file mode 100644 index 0000000000000000000000000000000000000000..d820f874f942c1e29219657293cb3911f5041b92 GIT binary patch literal 944 zcmaKq%PvDv6owZLE$R{)G3da+z(5+;#K1s81R)XWz(CT7o}?#fPo;_%FcL#U@f>1g z>>TRLxu71{C4Ip3U9BL*^JT-2ku+&qt4Q9UeQgw3<4 z5!=gqk_57iiuvMlaZbBmBRcHuZXZmoH)8p>!b(v$!>azfmtoz|!>m{m)UhV|L{0VO z1-)E}NkN~wa*KMQS8(?vm<9FJkE1x%N~e~Q3R+s-#8Bt7pl+s0r*uvK#QWSnzeeO1 zHngsjB+(XN-}y)qGEg)6GXg*A8fe4+eo+ZeTh+_xEY$f%ad@PfS?5b0ukW?cNcHjM zHBw(4Ca^uX_0j9p_i@CMi(0<)^;JD6?p}RhL-)78pI=6*e0P3+zD9QRi2IX&LGR~g z)Te3#Hyf`b_aBQ{l8ZY3%{iaud~Eyk4(QJbauW}}o*|w0YsV*HwNa~t&9cH>CHT@? d-bJO~2fd+9Z08I+-w@lN=*VerZhnEw^a;xbW;_4@ literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/depthwise_conv2.tflite new file mode 100644 index 0000000000000000000000000000000000000000..9baa40effdd1924ae945bfcee0a4706f631adf3b GIT binary patch literal 936 zcmZXS&q@MO6vj_Asg;$KgbNofTC@n!{y>WsK@bE65-tOg9B4qtP^%Yc)!JqBAgz3c zR;{83siyBY<1MR8ANO|e`Tm^a6wR!DT-`BSv7&jFv8?4RuiJud%w`1sD+grMCC31+xNoBY|0vr_&6=#4DLvS}u5LO~WDQhCptnN~d&9|4!HK=xapr zLL0C9D2lXs)SP`r5ej6+e^a0i-uxJF=7g4iHMqPD`o7LX@uv)QtMyVidkE=Pvd&t@huBz!goEAWl^ZLn-~6Nx83r44bkHS g)Y4NvW+mSRJ>d;+a&~`zy9K>ZXxh87c+x@h3+A0>5C8xG literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div1.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div2.tflite deleted file mode 100644 index bc367d8f1e800dcb2ce78994aebcd66df8943423..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 572 zcmYjO!AiqW5F9mXt+5vM&_jeAdMF~%B8bO=AP5yiJa`MmJTTBC5Yy7L_yK;9AK=Mv z@aRGO0x{0KH0r{<-M87Doqh2DT%Qe(fGv1%(143Jc9`uj11u5M4S2L!V4r@0=<$$$ zB9@5{Ufzit;*B_Aj*g75ykc+I7@P!b6G@zg`6Mmk#Dr6m-^W>s&+oF-_cqZFj)Oz` z+BO+n49~Cp(=26uVWNPU7S;Y~Z!sgyKG7w5gfwdxmx5mCXN0uL$O+?u0CV+>HWmFMilwPMDaB2 zGk#`0utFAxXH}y*U@&-{bE#i?q+Pa7{Bj)3Y-iP@eH2xHS^G$fbZO5n@n6dZ@0a7z St+~s_k;$!|FV1wxA(kIm`#)m< diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/div3.tflite deleted file mode 100644 index a8b7e5ec7b4223fa67a9152c47d1d7a201603176..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 572 zcmaKpy-LJT5QRsL>$+%M)FMTM6c!?~_=mM12*L^?_F*vr1Dgb5cCqjg#M%e(349U@ z3kx4Xh~G&JZsmn@=ZBed&&>7B?B;xUWVUX;Io7om>)R$-kBr#@Xd<(xI%~Fvodd^Q zvM*o}d=Y$t5AY67$zfo$npf=I+BaS@+W={jg~cQ*lQa&eadDsI8QBhX4tG=6#~ut$ z2K(6dtl?Y^FRp{LJfp9SqX8Kg)qgX%WTZU+J3zX{p_v?ny<#81JkXI=L9wh4q%$2Q zh;nN#?PrM%iJP{Yi>l_~@6J&YSsX=5@{N`gC$^*0!!ytZq$ntV7f#VzvSL*M&uCkFe|TzGij^ z{OlDr!R4X@4m4pzO!rNYwmQ^v-|)?LAq=8v@))H-=ue;gJLnIF zeKjx74)$y^zU-aE5p}8W^@+5Y{BMsI5zXI&eK-I~Xs;oIePUD4j5@N6o%4p(-kLL< z1q>za;)QCcN}IE%8W*&qa=(7EY*mL)uKSW@8E#v5H~89V`My1T_0=0nV++Qf&!eQz z=lF}fwo&$|TZ5SDl`Fj^xE$rxutgqDOZOjTPxV>8$!bqkTm^Sw9C+wfb8YS1$-QE4 d?X14i`%hz=_{Ds7>)*$7&rj@|o)zjf=Pyh8Jd^+c literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sigmoid.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/logistic.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/avg_pooling.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mean_pooling.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul1.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul2.tflite deleted file mode 100644 index 089de5cb2d746a2aaa83a836bbace87f723953da..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 564 zcmYjOu}T9$6dd(5A!rOSLW(O=q(~vBK?KWyAP5PFq(6c=@ZfF_&P4J7R#twJg&IGsZxO@jEIJ`q zvCqw9e09^iC=BZt$vk4FNwdBkuKw~%SAJ(u4ApO`lX@px((N*?uJjdm49Nr=BnOR_o+|sCq40GnZ-l$3u}QL zDg<2pchvx+(c4-_z49VYiW=eik5{hC*J}>#{NL3a@*qE&uTT8g M&dWFN3#cKLU(g>rd;kCd diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/mul3.tflite deleted file mode 100644 index a0b90cde827a2fbf8c03c970d7b69a82d0762c9d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 568 zcmaKpy-EW?6os$4#t?K3Q6a?^DJ(2x@h`m~2tooP_De7eF60 zVN-kotKf^|6MTSoa7vAUqB&iux82!z$!rT`NgkE+e3oQ!bRU;@Ns&`+!SmQ_UYB?n z4#NXtvo<-5M;BLvvm)o-ES`o`d|dza;8T(J0oVibEe-V~W3Plqu&kUVuYz*L0LW)H zNigL*EzM_*2}zr^(`QqE;ssSzHTM@a%{uOurqF@M>WTLZU%fRm)_c0ASxQfFJWrz| z;se~bU%oi%OFb>}H^?v3$xVD)q|>;J_*8USPOJ7W)%u6!UvJBYw06ONTJrGC`yOgi F`2m6sIxheK diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/realdiv.tflite new file mode 100644 index 0000000000000000000000000000000000000000..13f3cf78a1beec45a092d99ebd5889c73626134b GIT binary patch literal 588 zcmZvZK}!Nr6opS^I#`)h3l|11T112-MeCpl0)vqDB012&ID(n0R;~L1{gmiewP?{o zuH=V}iFr#lA_nTBIxc;A^LdD)2;1JliH9S)lV;{s9unK(a)o|`u2lhGWC77xN zE$^qqhOEumuW@-k^P<^oHt%25l**yYTPVZh9GUkD-?Q0|-#P%+aqfK@&P$*Z_tL9m zqLiM`zdpa=KF_qQZ;Sd}+Ixr}hiN|^bz6Z~4BUJkcjbTF^VrLLZtrWE<_7fzO-=Dq HfNAb8`tv&< literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/slice.tflite new file mode 100644 index 0000000000000000000000000000000000000000..79891fd475731a2fdeddea83240a0a4f039c88b3 GIT binary patch literal 676 zcmYjPJx>Bb6ddtHf&_^YLqS17VPQ^*7PdxXG{J-@Jwe#$Cih_uwBQ%`11v49DXsl! z#&c$mjlN{=y?Z-5Z+7q60EU_C zVF?%zA8cL`7pza&7pF!@TN5g-1~!;?iP$Ei>^Yg+*o$z1SRoM~+@&vU<5`W?d0^IE!-yU0!qGsyoWG)wm< zj^{1e=W9uN(ok*_eqBDD&f&~1DE997WjYGnxB9P1>MC@H6*@2e0o@_P&Ct0tqeJ)^ zTda#2_vV4uBxUL9TQpr&{7ur)UhUJlyM#Oo(p8q$AR151V;W6OhHrQ7z^VKKki0`o literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/stack.tflite new file mode 100644 index 0000000000000000000000000000000000000000..2689d11685577cbd40dde6d2324f1806e7b26e2d GIT binary patch literal 596 zcmZ9J!AiqG7=(W{YN<6A_0U5OJ$UdS!Qj!0P!WU*qTWMj0|uJ}(&!_2@&SAr&pwie zLi{FaFt{)~yZ`<>nb~+|_H;A7FgrBQ9P8P@hIUMKM8#|iG&{48x@vX?Z-8qi@h8{@ z&L$IBQTzc{Y;`^W(v^CzUegO7fi%g&;x#LiG!Eb5VwvO_)gAgA=B9552YwLv=PcIa z0KS{vK8z=M#=SD0`&3+9|7CEg$om+a0{NDNGI<#LhLuUSHRM&$UhDz+Orr!-Zmp#p zyS$LJ8JE&kUz<_Bm9eWMz3w;6`n8H)rOz4TQT^{#RVi!j4?Y3hc zk~?4Pd+Heh+MnN=-CPn I4r)^Q1!P$^kN^Mx literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite similarity index 100% rename from mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub1.tflite rename to mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub.tflite diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub2.tflite deleted file mode 100644 index dbca4e40a1a9f239d2644e2e31442715dcdaffaf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 564 zcmYjOu}T9$6dd(5A!rOHMT#p@q=+D=K@iJ;AP5me(;v|+c#zwJGm(6Nm6e~A!Y{D0 zwD1e8#dT&chd!8{-FNTJ&f5zCjLwJ0z#c*bsA3xpw3ux(18fk*33&2Z;DEeF^tgy$ zh)v>?hYw;zyc4I)k*El-EA^`1gEs+9B28?Z&+RNpP5fx`ha|K3`f*O*8$>@kiH^u+ z>~eECytwY1WtR0bGl`g~)9s)4Iy3UzA-Y75kY~*j(9kP+LCBjzyuRX)th>{3Vx=n` zy-9^`q{|4~Tr`vFTF$w$ztmuwQS)8ut`&)W)jhLbbCvI^`_!lRlb-lyejE46OV$Eg zR0y~h-&F$)25&1J^~#GpDJq2TKV0~(qF!@oSNvVgArJDS`MSh^4?8?xO~-fUKATQV N?!A2TzJMB1`32_lJCgta diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/sub3.tflite deleted file mode 100644 index c223cdbf90f32c04494de874f1be4fb1a37b0c72..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 568 zcmaKp%}T>S6os!E+fp!OP+x^ST&!9Thd1VN}E>ONE(Fwi6rTXEqdxN_xV_zXUW z3l}bY2qAuFVz4U*&fK||GiT1d6PVf6$?(8z-2(HhWh>UP4XQmVW=o*T%${7;Y=^i2 z4%rl6z%uwE`2-)}9UM_3plD84>MeIRUNY-~G|8f3mQ9m1j_%{)F3B^h4R{`V#p@9F z!a=x8tk)Wcv*GDQ|2WUMH;u<36(84sJ@{1Qy$`m4d`m+;$=EC55iBZa$*Z7TF#z(J zjuK4yPD}G?G9hWRf*zat6E7&svc5mBY1VePG=(-i*0y-h@YP#AW4)()nx*s@n^`)NX`K#pTlhI9ln@`4Z5%H<$G@MrDU#is)%fH^14{2?J|Fq=coA*7` Gr1Arqb~;J` diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite b/mindspore/lite/test/ut/tools/converter/parser/tflite/test_data/transpose.tflite new file mode 100644 index 0000000000000000000000000000000000000000..6e7e08695ab39d2761bac9f98e61d9e8b491e466 GIT binary patch literal 584 zcmYLG!AiqG6dY@;u}Ur4Lk~IRAm~A&B8YcE6e@^#FVQYE&}>Q){Q^J05Ac)x6ptP} zNaM_|(S^ypef!?L*-ZrC{&IW@>>xsb5L*~v$n1z2V3qJY;KgNuL-LAZe#EcD8u7)= zCvio55a-O1s0gPk_0YxOBw&}wGdn4tY@Ow1@@&c~D=e1FrqTI;I7!ZuV{*H;8QhGo z@8XNXvc5Lcgqc3w{(J8;BhN9hM;s9Hth+=s^iG};@}?#3{PKvb*?F2-py#|xxyPLR z=&Y?*kCxX_?K*e9mchRA4^7jkuYJduvHAq%Hi)`RZM7&W!}yK$K$j{3zkZ(H%+Fxp z?izBE|4+IrS4T6XDb}4^k2?R;KU$bFmuA=9Ls~-@`M<7O^{cmizt*RClenFyk7icP Or>4X=y$0;poBRM^I5|Q9 literal 0 HcmV?d00001 diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc index 24f2a34899b..84ec400efc5 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_activation_parser_test.cc @@ -31,6 +31,12 @@ TEST_F(TestTfliteParserRelu, OpType) { ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; } +TEST_F(TestTfliteParserRelu, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_RELU); +} + class TestTfliteParserRelu6 : public TestTfliteParser { public: TestTfliteParserRelu6() = default; @@ -43,6 +49,12 @@ TEST_F(TestTfliteParserRelu6, OpType) { ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; } +TEST_F(TestTfliteParserRelu6, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_RELU6); +} + class TestTfliteParserTanh : public TestTfliteParser { public: TestTfliteParserTanh() = default; @@ -55,7 +67,45 @@ TEST_F(TestTfliteParserTanh, OpType) { ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; } -// logistic +TEST_F(TestTfliteParserTanh, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_TANH); +} + +class TestTfliteParserLogistic : public TestTfliteParser { + public: + TestTfliteParserLogistic() = default; + void SetUp() override { meta_graph = LoadAndConvert("./logistic.tflite", ""); } +}; + +TEST_F(TestTfliteParserLogistic, OpType) { + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; +} +TEST_F(TestTfliteParserLogistic, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); +} + +class TestTfliteParserHardSwish : public TestTfliteParser { + public: + TestTfliteParserHardSwish() = default; + void SetUp() override { meta_graph = LoadAndConvert("./hardswish.tflite", ""); } +}; + +TEST_F(TestTfliteParserHardSwish, OpType) { + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Activation) << "wrong Op Type"; +} +TEST_F(TestTfliteParserHardSwish, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsActivation(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsActivation(); + ASSERT_EQ(val->type, schema::ActivationType_SIGMOID); +} class TestTfliteParserPrelu : public TestTfliteParser { public: @@ -73,12 +123,11 @@ TEST_F(TestTfliteParserPrelu, OpType) { } TEST_F(TestTfliteParserPrelu, AttrValue) { - std::vector slope(20, 0); - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPrelu(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsPrelu()->slope, slope); + auto val = meta_graph->nodes.front()->primitive->value; + std::vector slope(20, 0); + ASSERT_EQ(val.AsPrelu()->slope, slope); + ASSERT_EQ(val.type, schema::PrimitiveType_Prelu); } class TestTfliteParserLeakyRelu : public TestTfliteParser { @@ -94,12 +143,10 @@ TEST_F(TestTfliteParserLeakyRelu, OpType) { } TEST_F(TestTfliteParserLeakyRelu, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - - auto val = meta_graph->nodes.front()->primitive->value.AsLeakyReLU(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->negativeSlope, 0.20000000298023224); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLeakyReLU(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value; + ASSERT_EQ(val.AsLeakyReLU()->negativeSlope, 0.20000000298023224); + ASSERT_EQ(val.type, schema::PrimitiveType_LeakyReLU); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc index 72a82376736..7480d19b6c0 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_addn_parser_test.cc @@ -35,10 +35,8 @@ TEST_F(TestTfliteParserAddN, OpType) { } TEST_F(TestTfliteParserAddN, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsAddN(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsAddN()->N, 4); + auto val = meta_graph->nodes.front()->primitive->value.AsAddN(); + ASSERT_EQ(val->N, 4); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc new file mode 100644 index 00000000000..465930039d4 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmax_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserArgmax : public TestTfliteParser { + public: + TestTfliteParserArgmax() = default; + void SetUp() override { meta_graph = LoadAndConvert("./argmax.tflite", ""); } +}; + +TEST_F(TestTfliteParserArgmax, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMax) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserArgmax, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMax(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsArgMax(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->topK, 1); + ASSERT_EQ(val->axisType, 1); + ASSERT_EQ(val->keepDims, false); + ASSERT_EQ(val->outMaxValue, false); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc index 0bb625ec878..03f9e1c1bf2 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_argmin_parser_test.cc @@ -25,15 +25,14 @@ class TestTfliteParserArgmin : public TestTfliteParser { }; TEST_F(TestTfliteParserArgmin, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_ArgMin) << "wrong Op Type"; } TEST_F(TestTfliteParserArgmin, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsArgMin(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsArgMin(); ASSERT_EQ(val->axis, 1); ASSERT_EQ(val->topK, 1); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc index 64056ecde41..fdf08db8cf4 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_arithmetic_parser_test.cc @@ -19,234 +19,57 @@ namespace mindspore { // doubleInputOp -class TestTfliteParserAdd1 : public TestTfliteParser { +class TestTfliteParserAdd : public TestTfliteParser { public: - TestTfliteParserAdd1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add1.tflite", ""); } + TestTfliteParserAdd() = default; + void SetUp() override { meta_graph = LoadAndConvert("./add.tflite", ""); } }; -TEST_F(TestTfliteParserAdd1, OpType) { +TEST_F(TestTfliteParserAdd, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; } -TEST_F(TestTfliteParserAdd1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserAdd2 : public TestTfliteParser { +class TestTfliteParserSub : public TestTfliteParser { public: - TestTfliteParserAdd2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add2.tflite", ""); } + TestTfliteParserSub() = default; + void SetUp() override { meta_graph = LoadAndConvert("./sub.tflite", ""); } }; -TEST_F(TestTfliteParserAdd2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserAdd2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserAdd3 : public TestTfliteParser { - public: - TestTfliteParserAdd3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./add3.tflite", ""); } -}; - -TEST_F(TestTfliteParserAdd3, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Add) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserAdd3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub1 : public TestTfliteParser { - public: - TestTfliteParserSub1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub1.tflite", ""); } -}; - -TEST_F(TestTfliteParserSub1, OpType) { +TEST_F(TestTfliteParserSub, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; } -TEST_F(TestTfliteParserSub1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub2 : public TestTfliteParser { +class TestTfliteParserMul : public TestTfliteParser { public: - TestTfliteParserSub2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub2.tflite", ""); } + TestTfliteParserMul() = default; + void SetUp() override { meta_graph = LoadAndConvert("./mul.tflite", ""); } }; -TEST_F(TestTfliteParserSub2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserSub2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserSub3 : public TestTfliteParser { - public: - TestTfliteParserSub3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./sub3.tflite", ""); } -}; - -TEST_F(TestTfliteParserSub3, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Sub) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserSub3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul1 : public TestTfliteParser { - public: - TestTfliteParserMul1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul1.tflite", ""); } -}; - -TEST_F(TestTfliteParserMul1, OpType) { +TEST_F(TestTfliteParserMul, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; } -TEST_F(TestTfliteParserMul1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul2 : public TestTfliteParser { +class TestTfliteParserDiv : public TestTfliteParser { public: - TestTfliteParserMul2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul2.tflite", ""); } + TestTfliteParserDiv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./div.tflite", ""); } }; -TEST_F(TestTfliteParserMul2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserMul2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserMul3 : public TestTfliteParser { - public: - TestTfliteParserMul3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./mul3.tflite", ""); } -}; - -TEST_F(TestTfliteParserMul3, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Mul) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserMul3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv1 : public TestTfliteParser { - public: - TestTfliteParserDiv1() = default; - void SetUp() override { meta_graph = LoadAndConvert("./div1.tflite", ""); } -}; - -TEST_F(TestTfliteParserDiv1, OpType) { +TEST_F(TestTfliteParserDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; } - -TEST_F(TestTfliteParserDiv1, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv2 : public TestTfliteParser { - public: - TestTfliteParserDiv2() = default; - void SetUp() override { meta_graph = LoadAndConvert("./div2.tflite", ""); } -}; - -TEST_F(TestTfliteParserDiv2, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserDiv2, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_GT(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - -class TestTfliteParserDiv3 : public TestTfliteParser { - public: - TestTfliteParserDiv3() = default; - void SetUp() override { meta_graph = LoadAndConvert("./div3.tflite", ""); } -}; - -TEST_F(TestTfliteParserDiv3, OpType) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; -} - -TEST_F(TestTfliteParserDiv3, Tensor) { - ASSERT_GT(meta_graph->allTensors.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(0)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(1)->data.size(), 0); - ASSERT_EQ(meta_graph->allTensors.at(2)->data.size(), 0); -} - class TestTfliteParserFloorDiv : public TestTfliteParser { public: TestTfliteParserFloorDiv() = default; @@ -254,6 +77,7 @@ class TestTfliteParserFloorDiv : public TestTfliteParser { }; TEST_F(TestTfliteParserFloorDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorDiv) << "wrong Op Type"; @@ -266,12 +90,26 @@ class TestTfliteParserFloorMod : public TestTfliteParser { }; TEST_F(TestTfliteParserFloorMod, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_FloorMod) << "wrong Op Type"; } -// realDiv +class TestTfliteParserRealDiv : public TestTfliteParser { + public: + TestTfliteParserRealDiv() = default; + void SetUp() override { + meta_graph = LoadAndConvert("./realdiv.tflite"); + } +}; + +TEST_F(TestTfliteParserRealDiv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Div) << "wrong Op Type"; +} class TestTfliteParserSquaredDifference : public TestTfliteParser { public: @@ -296,17 +134,15 @@ class TestTfliteParserPow : public TestTfliteParser { }; TEST_F(TestTfliteParserPow, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Power) << "wrong Op Type"; } TEST_F(TestTfliteParserPow, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPower(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPower(); - ASSERT_EQ(val->scale, 1.0); ASSERT_EQ(val->shift, 0.0); ASSERT_EQ(val->power, 0.0); @@ -477,6 +313,7 @@ class TestTfliteParserFloor : public TestTfliteParser { }; TEST_F(TestTfliteParserFloor, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Floor) << "wrong Op Type"; diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc index c78d8eac54b..8091bc65986 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_batch_to_space_nd_parser_test.cc @@ -32,14 +32,12 @@ TEST_F(TestTfliteParserBatchToSpaceNd, OpType) { } TEST_F(TestTfliteParserBatchToSpaceNd, AttrValue) { - const std::vector blockShape{2, 2}; - const std::vector crops{0, 0, 2, 0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsBatchToSpace(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->blockShape, blockShape); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsBatchToSpace()->crops, crops); + auto val = meta_graph->nodes.front()->primitive->value.AsBatchToSpace(); + const std::vector blockShape = {2, 2}; + ASSERT_EQ(val->blockShape, blockShape); + const std::vector crops = {0, 0, 2, 0}; + ASSERT_EQ(val->crops, crops); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc index 28fe10828c0..4dd4c41a755 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_cast_parser_test.cc @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserCast, OpType) { } TEST_F(TestTfliteParserCast, AttrValue) { - // float32 --> int32 - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsCast(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->srcT, 43); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsCast()->dstT, 34); + auto val = meta_graph->nodes.front()->primitive->value.AsCast(); + ASSERT_EQ(val->srcT, 43); + ASSERT_EQ(val->dstT, 34); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc new file mode 100644 index 00000000000..2d5fd6fe806 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_concat_parser_test.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserConcat : public TestTfliteParser { + public: + TestTfliteParserConcat() = default; + void SetUp() override { meta_graph = LoadAndConvert("./concat.tflite", ""); } +}; + +TEST_F(TestTfliteParserConcat, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Concat) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserConcat, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->n, 2); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc new file mode 100644 index 00000000000..dc2a763d499 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_conv_parser_test.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserConv : public TestTfliteParser { + public: + TestTfliteParserConv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./conv.tflite", ""); } +}; + +TEST_F(TestTfliteParserConv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserConv, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 1); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc new file mode 100644 index 00000000000..9ff4704c7c2 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_deconv_parser_test.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserDeConv : public TestTfliteParser { + public: + TestTfliteParserDeConv() = default; + void SetUp() override { meta_graph = LoadAndConvert("./deconv.tflite", ""); } +}; + +TEST_F(TestTfliteParserDeConv, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DeConv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDeConv, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsDeConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 1); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc index 5283c0a85de..08abf718360 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depth_to_space_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserDepthToSpace, OpType) { } TEST_F(TestTfliteParserDepthToSpace, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsDepthToSpace(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->blockSize, 4); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsDepthToSpace()->format, schema::Format_NHWC); + auto val = meta_graph->nodes.front()->primitive->value.AsDepthToSpace(); + ASSERT_EQ(val->blockSize, 4); + ASSERT_EQ(val->format, schema::Format_NHWC); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc new file mode 100644 index 00000000000..6ba9c4d1e68 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_depthwise_conv_parser_test.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserDepthwiseConv1 : public TestTfliteParser { + public: + TestTfliteParserDepthwiseConv1() = default; + void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv1.tflite", ""); } +}; + +TEST_F(TestTfliteParserDepthwiseConv1, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_Conv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDepthwiseConv1, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->group, 0); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 1); + ASSERT_EQ(val->channelOut, 4); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +class TestTfliteParserDepthwiseConv2 : public TestTfliteParser { + public: + TestTfliteParserDepthwiseConv2() = default; + void SetUp() override { meta_graph = LoadAndConvert("./depthwise_conv2.tflite", ""); } +}; + +TEST_F(TestTfliteParserDepthwiseConv2, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reshape) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.at(1)->primitive->value.type, schema::PrimitiveType_DepthwiseConv2D) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserDepthwiseConv2, AttrValue) { + ASSERT_NE(meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(), nullptr); + auto val = meta_graph->nodes.at(1)->primitive->value.AsDepthwiseConv2D(); + ASSERT_EQ(val->format, schema::Format_NHWC); + ASSERT_EQ(val->activationType, schema::ActivationType_NO_ACTIVATION); + ASSERT_EQ(val->hasBias, true); + ASSERT_EQ(val->channelIn, 2); + ASSERT_EQ(val->channelMultiplier, 1); + ASSERT_EQ(val->kernelH, 3); + ASSERT_EQ(val->kernelW, 3); + ASSERT_EQ(val->strideH, 1); + ASSERT_EQ(val->strideW, 1); + ASSERT_EQ(val->dilateH, 1); + ASSERT_EQ(val->dilateW, 1); + ASSERT_EQ(val->padMode, schema::PadMode_SAME); + ASSERT_EQ(val->padUp, 1); + ASSERT_EQ(val->padDown, 1); + ASSERT_EQ(val->padLeft, 1); + ASSERT_EQ(val->padRight, 1); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc index 359d54b6e54..4ce37d77a46 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_fill_parser_test.cc @@ -25,17 +25,15 @@ class TestTfliteParserFill : public TestTfliteParser { }; TEST_F(TestTfliteParserFill, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Fill) << "wrong Op Type"; } -TEST_F(TestTfliteParserFill, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - +TEST_F(TestTfliteParserFill, AttrValue) {; + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsFill(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsFill(); - std::vector dims = {9}; ASSERT_EQ(val->dims, dims); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc index d6badec6a72..5c4af9e7f91 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_nd_parser_test.cc @@ -25,15 +25,14 @@ class TestTfliteParserGatherNd : public TestTfliteParser { }; TEST_F(TestTfliteParserGatherNd, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_GatherNd) << "wrong Op Type"; } TEST_F(TestTfliteParserGatherNd, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGatherNd(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsGatherNd(); ASSERT_EQ(val->batchDims, 0); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc index 13420abadac..071738a15a7 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_gather_parser_test.cc @@ -25,15 +25,14 @@ class TestTfliteParserGather : public TestTfliteParser { }; TEST_F(TestTfliteParserGather, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Gather) << "wrong Op Type"; } TEST_F(TestTfliteParserGather, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsGather(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsGather(); ASSERT_EQ(val->axis, 0); ASSERT_EQ(val->batchDims, 0); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc index 88a504a6aef..8db8351d4b1 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_lrn_parser_test.cc @@ -25,6 +25,7 @@ class TestTfliteParserLRN : public TestTfliteParser { }; TEST_F(TestTfliteParserLRN, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, @@ -32,9 +33,7 @@ TEST_F(TestTfliteParserLRN, OpType) { } TEST_F(TestTfliteParserLRN, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsLocalResponseNormalization(); ASSERT_EQ(val->alpha, 1); ASSERT_EQ(val->beta, 0.5); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc index ed3d946a967..2c7a40b910c 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_one_hot_parser_test.cc @@ -32,12 +32,9 @@ TEST_F(TestTfliteParserOneHot, OpType) { } TEST_F(TestTfliteParserOneHot, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsOneHot(), nullptr); - // in OneHot parser axis = axis > 0 ? axis : axis + tensor_shape.size() - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsOneHot()->axis, 2); + auto val = meta_graph->nodes.front()->primitive->value.AsOneHot(); + ASSERT_EQ(val->axis, 2); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc index ada46577596..1d33e1a8fc8 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pad_parser_test.cc @@ -25,17 +25,15 @@ class TestTfliteParserPad : public TestTfliteParser { }; TEST_F(TestTfliteParserPad, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Pad) << "wrong Op Type"; } TEST_F(TestTfliteParserPad, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPad(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPad(); - std::vector paddings = {1, 1, 2, 2, 3, 3, 4, 4}; ASSERT_EQ(val->paddings, paddings); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc index d234a76a323..99d2500d598 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_pooling_parser_test.cc @@ -35,12 +35,8 @@ TEST_F(TestTfliteParserMaxPooling, OpType) { } TEST_F(TestTfliteParserMaxPooling, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->poolingMode, schema::PoolMode_MAX_POOLING); ASSERT_EQ(val->global, false); @@ -72,12 +68,8 @@ TEST_F(TestTfliteParserAvgPooling, OpType) { } TEST_F(TestTfliteParserAvgPooling, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsPooling(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsPooling(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->format, schema::Format_NHWC); ASSERT_EQ(val->poolingMode, schema::PoolMode_MEAN_POOLING); ASSERT_EQ(val->global, false); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc index bf790bd842f..86928b867be 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reduce_parser_test.cc @@ -32,13 +32,9 @@ TEST_F(TestTfliteParserReduceMax, OpType) { } TEST_F(TestTfliteParserReduceMax, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMax); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -58,13 +54,9 @@ TEST_F(TestTfliteParserReduceMin, OpType) { } TEST_F(TestTfliteParserReduceMin, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMin); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -84,13 +76,9 @@ TEST_F(TestTfliteParserReduceProd, OpType) { } TEST_F(TestTfliteParserReduceProd, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceProd); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -111,13 +99,9 @@ TEST_F(TestTfliteParserSum, OpType) { } TEST_F(TestTfliteParserSum, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceSum); ASSERT_EQ(val->keepDims, false); std::vector axes = {2}; ASSERT_EQ(val->axes, axes); @@ -138,13 +122,9 @@ TEST_F(TestTfliteParserMean, OpType) { } TEST_F(TestTfliteParserMean, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReduce(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReduce(); - ASSERT_NE(val, nullptr); - ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean) << "wrong reduce mode"; + ASSERT_EQ(val->mode, schema::ReduceMode_ReduceMean); ASSERT_EQ(val->keepDims, true); std::vector axes = {2, 3}; ASSERT_EQ(val->axes, axes); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc index a4eef7898b3..3d29e9c1929 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reshape_parser_test.cc @@ -35,12 +35,9 @@ TEST_F(TestTfliteParserReshape, OpType) { } TEST_F(TestTfliteParserReshape, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReshape(), nullptr); - + auto val = meta_graph->nodes.front()->primitive->value.AsReshape(); std::vector shape = {3, 5, 20}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReshape()->shape, shape); // int32 + ASSERT_EQ(val->shape, shape); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc index ab0c9a9a256..cfff0c88ae5 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_resize_parser_test.cc @@ -26,17 +26,15 @@ class TestTfliteParserResizeNN : public TestTfliteParser { }; TEST_F(TestTfliteParserResizeNN, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; } TEST_F(TestTfliteParserResizeNN, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsResize(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->newHeight, 3); ASSERT_EQ(val->newWidth, 100); @@ -52,17 +50,15 @@ class TestTfliteParserResizeBilinear : public TestTfliteParser { }; TEST_F(TestTfliteParserResizeBilinear, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Resize) << "wrong Op Type"; } TEST_F(TestTfliteParserResizeBilinear, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsResize(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsResize(); - ASSERT_NE(val, nullptr); ASSERT_EQ(val->alignCorners, false); ASSERT_EQ(val->newHeight, 75); ASSERT_EQ(val->newWidth, 4); diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc index ea1ffff935c..e4a03440baa 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_parser_test.cc @@ -25,17 +25,15 @@ class TestTfliteParserReverse : public TestTfliteParser { }; TEST_F(TestTfliteParserReverse, OpType) { + ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Reverse) << "wrong Op Type"; } TEST_F(TestTfliteParserReverse, AttrValue) { - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverse(), nullptr); auto val = meta_graph->nodes.front()->primitive->value.AsReverse(); - std::vector axis = {3}; ASSERT_EQ(val->axis, axis); } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc index a5f3e58d997..ba5b7c52208 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_reverse_sequence_parser_test.cc @@ -35,13 +35,11 @@ TEST_F(TestTfliteParserReverseSequence, OpType) { } TEST_F(TestTfliteParserReverseSequence, AttrValue) { - std::vector seq_length{7, 2, 3, 5}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsReverseSequence(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqAxis, 1); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsReverseSequence()->seqLengths, seq_length); + auto val = meta_graph->nodes.front()->primitive->value.AsReverseSequence(); + ASSERT_EQ(val->seqAxis, 1); + ASSERT_EQ(val->seqAxis, 1); + std::vector seq_length = {7, 2, 3, 5}; + ASSERT_EQ(val->seqLengths, seq_length); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc new file mode 100644 index 00000000000..655f114eeda --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_slice_parser_test.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserSlice : public TestTfliteParser { + public: + TestTfliteParserSlice() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./slice.tflite"); } +}; + +TEST_F(TestTfliteParserSlice, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Slice) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserSlice, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSlice(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsSlice(); + ASSERT_EQ(val->format, schema::Format_NHWC); + std::vector begin = {1, 0, 0}; + ASSERT_EQ(val->begin, begin); + std::vector size = {1, 1, 3}; + ASSERT_EQ(val->size, size); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc index c6fb258e6ee..88488f946f4 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_softmax_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSoftmax, OpType) { } TEST_F(TestTfliteParserSoftmax, AttrValue) { -ASSERT_NE(meta_graph, nullptr); -ASSERT_GT(meta_graph->nodes.size(), 0); -ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); -ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); -ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSoftMax()->axis, -1); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSoftMax(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsSoftMax(); + ASSERT_EQ(val->axis, -1); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc index dab7110c1af..ae80d1481a1 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser_test.cc @@ -35,13 +35,11 @@ TEST_F(TestTfliteParserSpaceToBatchND, OpType) { } TEST_F(TestTfliteParserSpaceToBatchND, AttrValue) { - std::vector blockshape{2, 2}; - std::vector padding{0, 0, 2, 0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->blockShape, blockshape); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND()->paddings, padding); + auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToBatchND(); + std::vector blockshape = {2, 2}; + ASSERT_EQ(val->blockShape, blockshape); + std::vector padding = {0, 0, 2, 0}; + ASSERT_EQ(val->paddings, padding); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc index 785baff5177..d1ed72bee24 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_space_to_depth_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserSpaceToDepth, OpType) { } TEST_F(TestTfliteParserSpaceToDepth, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->blockSize, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSpaceToDepth()->format, schema::Format_NHWC); + auto val = meta_graph->nodes.front()->primitive->value.AsSpaceToDepth(); + ASSERT_EQ(val->blockSize, 2); + ASSERT_EQ(val->format, schema::Format_NHWC); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc index 93e7badacf1..c6b7fc1c1da 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_sparse_to_dense_parser_test.cc @@ -35,16 +35,14 @@ TEST_F(TestTfliteParserSparseToDense, OpType) { } TEST_F(TestTfliteParserSparseToDense, AttrValue) { - std::vector outputShape{5, 5}; - std::vector sparseValue{1}; - std::vector defaultValue{0}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSparseToDense(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->outputShape, outputShape); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->sparseValue, sparseValue); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->defaultValue, defaultValue); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSparseToDense()->validateIndices, false); + auto val = meta_graph->nodes.front()->primitive->value.AsSparseToDense(); + std::vector outputShape = {5, 5}; + ASSERT_EQ(val->outputShape, outputShape); + std::vector sparseValue = {1}; + ASSERT_EQ(val->sparseValue, sparseValue); + std::vector defaultValue = {0}; + ASSERT_EQ(val->defaultValue, defaultValue); + ASSERT_EQ(val->validateIndices, false); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc index b7c9e38cb91..97cb01d9995 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_parser_test.cc @@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplit, OpType) { } TEST_F(TestTfliteParserSplit, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); - const std::vector sizeSplits{2, 2}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); + auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); + ASSERT_EQ(val->splitDim, 2); + ASSERT_EQ(val->numberSplit, 2); + const std::vector sizeSplits = {2, 2}; + ASSERT_EQ(val->sizeSplits, sizeSplits); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc index 76eb5642470..b0c6e781059 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_split_v_parser_test.cc @@ -33,14 +33,12 @@ TEST_F(TestTfliteParserSplitV, OpType) { } TEST_F(TestTfliteParserSplitV, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsSplit(), nullptr); - const std::vector sizeSplits{1, 3}; - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->splitDim, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->numberSplit, 2); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsSplit()->sizeSplits, sizeSplits); + auto val = meta_graph->nodes.front()->primitive->value.AsSplit(); + ASSERT_EQ(val->splitDim, 0); + ASSERT_EQ(val->numberSplit, 2); + const std::vector sizeSplits = {1, 3}; + ASSERT_EQ(val->sizeSplits, sizeSplits); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc new file mode 100644 index 00000000000..ff6d01841ac --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_stack_parser_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserStack : public TestTfliteParser { + public: + TestTfliteParserStack() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./stack.tflite"); } +}; + +TEST_F(TestTfliteParserStack, OpType) { + ASSERT_NE(meta_graph, nullptr); + ASSERT_GT(meta_graph->nodes.size(), 0); + ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Stack) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserStack, AttrValue) { + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStack(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsStack(); + ASSERT_EQ(val->axis, 1); + ASSERT_EQ(val->n, 2); + const std::vector isScale = {3, 2, 3}; + ASSERT_EQ(val->isScale, isScale); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc index 88177cd9aa2..ef6fa94a7a6 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_strided_slice_parser_test.cc @@ -35,21 +35,19 @@ TEST_F(TestTfliteParserStridedSlice, OpType) { } TEST_F(TestTfliteParserStridedSlice, AttrValue) { - std::vector begin{1, -1, 0}; - std::vector end{2, -3, 3}; - std::vector stride{1, -1, 1}; - std::vector isscale{3, 2, 3}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsStridedSlice(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->endMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->beginMask, 0); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->begin, begin); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->end, end); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->stride, stride); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsStridedSlice()->isScale, isscale); + auto val = meta_graph->nodes.front()->primitive->value.AsStridedSlice(); + ASSERT_EQ(val->beginMask, 0); + ASSERT_EQ(val->endMask, 0); + ASSERT_EQ(val->beginMask, 0); + ASSERT_EQ(val->beginMask, 0); + std::vector begin = {1, -1, 0}; + ASSERT_EQ(val->begin, begin); + std::vector end = {2, -3, 3}; + ASSERT_EQ(val->end, end); + std::vector stride = {1, -1, 1}; + ASSERT_EQ(val->stride, stride); + std::vector isscale = {3, 2, 3}; + ASSERT_EQ(val->isScale, isscale); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc index e025bad6f95..fe4d930acd4 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_tile_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserTile, OpType) { } TEST_F(TestTfliteParserTile, AttrValue) { - std::vector multiply{2, 3, 4}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTile(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTile()->multiples, multiply); + auto val = meta_graph->nodes.front()->primitive->value.AsTile(); + std::vector multiply = {2, 3, 4}; + ASSERT_EQ(val->multiples, multiply); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc index 68e36296119..a2623e83901 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc @@ -35,13 +35,10 @@ TEST_F(TestTfliteParserTopKV2, OpType) { } TEST_F(TestTfliteParserTopKV2, AttrValue) { - // attr->sorted default is true - std::vector k{3}; - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->k, k); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsTopKV2()->sorted, true); + auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2(); + std::vector k = {3}; + ASSERT_EQ(val->k, k); + ASSERT_EQ(val->sorted, true); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc new file mode 100644 index 00000000000..f5891da3bf8 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_transpose_parser_test.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ut/tools/converter/parser/tflite/tflite_parsers_test_utils.h" +#include +#include "common/common_test.h" + +namespace mindspore { +class TestTfliteParserTranspose : public TestTfliteParser { + public: + TestTfliteParserTranspose() = default; + + void SetUp() override { meta_graph = LoadAndConvert("./transpose.tflite"); } +}; + +TEST_F(TestTfliteParserTranspose, OpType) { +ASSERT_NE(meta_graph, nullptr); +ASSERT_GT(meta_graph->nodes.size(), 0); +ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); +ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_Transpose) << "wrong Op Type"; +} + +TEST_F(TestTfliteParserTranspose, AttrValue) { +ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTranspose(), nullptr); +auto val = meta_graph->nodes.front()->primitive->value.AsTranspose(); +ASSERT_EQ(val->conjugate, false); +std::vector perm = {1, 0}; +ASSERT_EQ(val->perm, perm); +} + +} // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc index 5b49883b781..0273adbfe6e 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unique_parser_test.cc @@ -35,10 +35,9 @@ TEST_F(TestTfliteParserUnique, OpType) { } TEST_F(TestTfliteParserUnique, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnique()->outType, 34); // int32 + auto val = meta_graph->nodes.front()->primitive->value.AsUnique(); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnique(), nullptr); + ASSERT_EQ(val->outType, 34); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc index 9c8e59232e9..44c020d2c63 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_unstack_parser_test.cc @@ -35,11 +35,9 @@ TEST_F(TestTfliteParserUnstack, OpType) { } TEST_F(TestTfliteParserUnstack, AttrValue) { - ASSERT_NE(meta_graph, nullptr); - ASSERT_GT(meta_graph->nodes.size(), 0); - ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsUnstack(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->num, 5); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.AsUnstack()->axis, 1); + auto val = meta_graph->nodes.front()->primitive->value.AsUnstack(); + ASSERT_EQ(val->num, 5); + ASSERT_EQ(val->axis, 1); } } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index a87e960c8c4..c863c502a8a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -353,7 +353,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } else if (weightTensor->format == schema::Format_KCHW) { status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { + } else if (weightTensor->format == schema::Format_CHWK) { // from tflite status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; @@ -369,7 +369,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { // from tf + } else if (weightTensor->format == schema::Format_CHWK) { // from tflite status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc index 9eb1f050952..f24264cd584 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_flatten_parser.cc @@ -21,11 +21,16 @@ namespace lite { STATUS CaffeFlattenParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, std::vector *weightVec) { if (op == nullptr) { - // MS_LOGE("null pointer dereferencing."); + // MS_LOG(ERROR) << "null pointer dereferencing."; return RET_NULL_PTR; } - std::unique_ptr attr(new schema::ReshapeT()); - attr->format = schema::Format_NCHW; + std::unique_ptr attr(new schema::FullConnectionT()); + const caffe::FlattenParameter flattenParam = proto.flatten_param(); + + attr->axis = (int32_t)flattenParam.axis(); + attr->useAxis = true; + attr->hasBias = false; + attr->activationType = schema::ActivationType_NO_ACTIVATION; op->primitive = std::make_unique(); op->primitive->value.type = schema::PrimitiveType_Flatten; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index 3ffb6ef901a..50fb64f71d4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_activation_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_activation_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteActivationParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteActivationParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,13 +38,11 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr &t MS_LOG(ERROR) << "op->primitive is null"; return RET_NULL_PTR; } - std::unique_ptr attr(new schema::ActivationT()); std::vector node_name_str; Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - if (std::strcmp(node_name, "Relu") == 0) { MS_LOG(DEBUG) << "parse TfliteReluParser"; attr->type = schema::ActivationType_RELU; @@ -54,29 +55,65 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr &t } else if (std::strcmp(node_name, "Logistic") == 0) { MS_LOG(DEBUG) << "parse TfliteLogisticParser"; attr->type = schema::ActivationType_SIGMOID; - } else if (std::strcmp(node_name, "LeakyRelu") == 0) { - const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions(); - if (option == nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->type = schema::ActivationType_LEAKY_RELU; - attr->alpha = option->alpha; - } else { - MS_LOG(ERROR) << "wrong activation type"; - return RET_ERROR; + } else if (std::strcmp(node_name, "HardSwish") == 0) { + MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; + attr->type = schema::ActivationType_SIGMOID; } + attr->alpha = 0.2f; op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } STATUS TflitePreluParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TflitePreluParser"; + + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + std::unique_ptr attr(new schema::PreluT()); + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { + MS_LOG(ERROR) << "get pRelu -> slope failed"; + return RET_ERROR; + } + op->primitive->value.type = schema::PrimitiveType_Prelu; + op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; +} + +STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -87,22 +124,29 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "paser TflitePreluParser"; - std::unique_ptr attr(new schema::PreluT()); + std::unique_ptr attr(new schema::LeakyReLUT()); - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->slope)) { - MS_LOG(ERROR) << "get pRelu -> slope failed"; - return RET_ERROR; + const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; } + attr->negativeSlope = tflite_attr->alpha; - op->primitive->value.type = schema::PrimitiveType_Prelu; + op->primitive->value.type = schema::PrimitiveType_LeakyReLU; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); +TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_tflitePreluParser("Prelu", new TflitePreluParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index 0c4b3509321..0223bcffac6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -14,13 +14,14 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RELU_PARSER_H -#define PREDICT_TFLITE_RELU_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" #include #include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" namespace mindspore { namespace lite { @@ -29,11 +30,13 @@ class TfliteActivationParser : public TfliteNodeParser { public: TfliteActivationParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteReluParser : public TfliteActivationParser { @@ -56,9 +59,9 @@ class TfliteLogisticParser : public TfliteActivationParser { TfliteLogisticParser() : TfliteActivationParser() {} }; -class TfliteLeakyReluParser : public TfliteActivationParser { +class TfliteHardSwishParser : public TfliteActivationParser { public: - TfliteLeakyReluParser() : TfliteActivationParser() {} + TfliteHardSwishParser() : TfliteActivationParser() {} }; class TflitePreluParser : public TfliteNodeParser { @@ -68,12 +71,27 @@ class TflitePreluParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; +}; + +class TfliteLeakyReluParser : public TfliteNodeParser { + public: + TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {} + + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RELU_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc index cf518ceafbb..b88666fca93 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -18,14 +18,20 @@ #include "tools/converter/parser/tflite/tflite_addn_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteAddNParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteAddNParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteAddNParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,13 +42,19 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteAddNParser"; std::unique_ptr attr(new schema::AddNT()); - attr->N = tfliteTensors.size() - 1; - + attr->N = tflite_tensors.size() - 1; op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.value = attr.release(); + + // set input + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h index bdc51bfc48b..8bd1ef03acb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_ADDN_PARSER_H -#define LITE_TFLITE_ADDN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteAddNParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_ADDN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADDN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index b3fe1143032..1ccf2ea4651 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_argmax_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +40,6 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; std::unique_ptr attr(new schema::ArgMaxT()); attr->outMaxValue = false; @@ -45,9 +47,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit attr->keepDims = false; attr->axisType = 1; - auto axis_idx = tfliteOp->inputs[1]; - std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); - auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -61,6 +64,11 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h index 0665a6b0289..f7dc10cfaf9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H -#define PREDICT_TFLITE_ARGMAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteArgmaxParser : public TfliteNodeParser { public: TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ARGMAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index 37a47c0ea35..1ce77d5ebac 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_argmin_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteArgminParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteArgminParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,7 +40,6 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteArgminParser"; std::unique_ptr attr(new schema::ArgMinT()); attr->outMaxValue = false; @@ -43,9 +47,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit attr->keepDims = false; attr->axisType = 1; - auto axis_idx = tfliteOp->inputs[1]; - std::for_each(tfliteTensors[axis_idx]->shape.begin(), tfliteTensors[axis_idx]->shape.end(), [&](int32_t sha){}); - auto &buf_data = tfliteModelBuffer[tfliteTensors[axis_idx]->buffer]; + // get axis attr + auto axis_idx = tflite_op->inputs[1]; + std::for_each(tflite_tensors[axis_idx]->shape.begin(), + tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha){}); + auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; return RET_NULL_PTR; @@ -59,6 +65,11 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h index a02d4fe5e2a..4213fc32117 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_ARGMIN_PARSER_H -#define PREDICT_TFLITE_ARGMIN_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteArgminParser : public TfliteNodeParser { public: TfliteArgminParser() : TfliteNodeParser("Argmin") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ARGMIN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARGMIN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 99eeb7f860b..8a7eb162366 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -18,14 +18,17 @@ #include #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,124 +40,72 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - - if (std::strcmp(node_name, "Add") == 0 - || std::strcmp(node_name, "Sub") == 0 - || std::strcmp(node_name, "Mul") == 0 - || std::strcmp(node_name, "Div") == 0) { - auto x_index = tfliteOp->inputs[0]; - const auto &x_tensor = tfliteTensors[x_index]; - if (x_tensor == nullptr) { - MS_LOG(ERROR) << "the first input is null"; + if (std::strcmp(node_name, "Add") == 0) { + MS_LOG(DEBUG) << "parse TfliteAddParser"; + std::unique_ptr attr(new schema::AddT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - auto &x_data = tfliteModelBuffer.at(x_tensor->buffer); - if (x_data == nullptr) { - MS_LOG(ERROR) << "the data of the first input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Sub") == 0) { + MS_LOG(DEBUG) << "parse TfliteSubParser"; + std::unique_ptr attr(new schema::SubT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - if (!x_data->data.empty()) { - std::vector x_tensors{x_tensor.get()}; - if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the first tensor failed"; - return RET_ERROR; - } - } - - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Mul") == 0) { + MS_LOG(DEBUG) << "parse TfliteMulParser"; + std::unique_ptr attr(new schema::MulT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Div") == 0) { + MS_LOG(DEBUG) << "parse TfliteDivParser"; + std::unique_ptr attr(new schema::DivT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions(); + if (nullptr == tfliteAttr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - - if (std::strcmp(node_name, "Add") == 0) { - MS_LOG(DEBUG) << "parse TfliteAddParser"; - std::unique_ptr attr(new schema::AddT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsAddOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Add; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Sub") == 0) { - MS_LOG(DEBUG) << "parse TfliteSubParser"; - std::unique_ptr attr(new schema::SubT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Sub; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Mul") == 0) { - MS_LOG(DEBUG) << "parse TfliteMulParser"; - std::unique_ptr attr(new schema::MulT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Mul; - op->primitive->value.value = attr.release(); - return RET_OK; - } else if (std::strcmp(node_name, "Div") == 0) { - MS_LOG(DEBUG) << "parse TfliteDivParser"; - std::unique_ptr attr(new schema::DivT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsDivOptions(); - if (nullptr == tfliteAttr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); - op->primitive->value.type = schema::PrimitiveType_Div; - op->primitive->value.value = attr.release(); - return RET_OK; - } + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + op->primitive->value.type = schema::PrimitiveType_Div; + op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "FloorDiv") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorDivParser"; std::unique_ptr attr(new schema::FloorDivT()); op->primitive->value.type = schema::PrimitiveType_FloorDiv; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "FloorMod") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorModParser"; std::unique_ptr attr(new schema::FloorModT()); op->primitive->value.type = schema::PrimitiveType_FloorMod; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "RealDiv") == 0) { MS_LOG(DEBUG) << "parse TfliteRealDivParser"; std::unique_ptr attr(new schema::RealDivT()); - op->primitive->value.type = schema::PrimitiveType_RealDiv; + op->primitive->value.type = schema::PrimitiveType_Div; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "SquaredDifference") == 0) { MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser"; std::unique_ptr attr(new schema::SquaredDifferenceT()); op->primitive->value.type = schema::PrimitiveType_SquaredDifference; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Pow") == 0) { MS_LOG(DEBUG) << "parse TflitePowParser"; std::unique_ptr attr(new schema::PowerT()); @@ -163,31 +114,35 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr attr->shift = 0.0f; op->primitive->value.type = schema::PrimitiveType_Power; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Maximum") == 0) { MS_LOG(DEBUG) << "parse TfliteMaximumParser"; std::unique_ptr attr(new schema::MaximumT()); op->primitive->value.type = schema::PrimitiveType_Maximum; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Minimum") == 0) { MS_LOG(DEBUG) << "parse TfliteMinimumParser"; std::unique_ptr attr(new schema::MinimumT()); op->primitive->value.type = schema::PrimitiveType_Minimum; op->primitive->value.value = attr.release(); - return RET_OK; - } else { - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + // set input + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } -STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -199,85 +154,79 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Abs") == 0) { MS_LOG(DEBUG) << "parse TfliteAbsParser"; std::unique_ptr attr(new schema::AbsT()); op->primitive->value.type = schema::PrimitiveType_Abs; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Exp") == 0) { MS_LOG(DEBUG) << "parse TfliteExpParser"; std::unique_ptr attr(new schema::ExpT()); op->primitive->value.type = schema::PrimitiveType_Exp; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Sqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteSqrtParser"; std::unique_ptr attr(new schema::SqrtT()); op->primitive->value.type = schema::PrimitiveType_Sqrt; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Rsqrt") == 0) { MS_LOG(DEBUG) << "parse TfliteRsqrtParser"; std::unique_ptr attr(new schema::RsqrtT()); op->primitive->value.type = schema::PrimitiveType_Rsqrt; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Square") == 0) { MS_LOG(DEBUG) << "parse TfliteSquareParser"; std::unique_ptr attr(new schema::SquareT()); op->primitive->value.type = schema::PrimitiveType_Square; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Sin") == 0) { MS_LOG(DEBUG) << "parse TfliteSinParser"; std::unique_ptr attr(new schema::SinT()); op->primitive->value.type = schema::PrimitiveType_Sin; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Cos") == 0) { MS_LOG(DEBUG) << "parse TfliteCosParser"; std::unique_ptr attr(new schema::CosT()); op->primitive->value.type = schema::PrimitiveType_Cos; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Log") == 0) { MS_LOG(DEBUG) << "parse TfliteLogParser"; std::unique_ptr attr(new schema::LogT()); op->primitive->value.type = schema::PrimitiveType_Log; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Round") == 0) { MS_LOG(DEBUG) << "parse TfliteRoundParser"; std::unique_ptr attr(new schema::RoundT()); op->primitive->value.type = schema::PrimitiveType_Round; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Ceil") == 0) { MS_LOG(DEBUG) << "parse TfliteCeilParser"; std::unique_ptr attr(new schema::CeilT()); op->primitive->value.type = schema::PrimitiveType_Ceil; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "flOOR") == 0) { MS_LOG(DEBUG) << "parse TfliteFloorParser"; std::unique_ptr attr(new schema::FloorT()); op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.value = attr.release(); - return RET_OK; - } else { - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } -STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -289,48 +238,47 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr &tf } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "Equal") == 0) { MS_LOG(DEBUG) << "parse TfliteEqualParser"; std::unique_ptr attr(new schema::EqualT()); op->primitive->value.type = schema::PrimitiveType_Equal; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "NotEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteNotEqualParser"; std::unique_ptr attr(new schema::NotEqualT()); op->primitive->value.type = schema::PrimitiveType_NotEqual; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Greater") == 0) { MS_LOG(DEBUG) << "parse TfliteGreaterParser"; std::unique_ptr attr(new schema::GreaterT()); op->primitive->value.type = schema::PrimitiveType_Greater; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "GreaterEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser"; std::unique_ptr attr(new schema::GreaterEqualT()); op->primitive->value.type = schema::PrimitiveType_GreaterEqual; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "Less") == 0) { MS_LOG(DEBUG) << "parse TfliteLessParser"; std::unique_ptr attr(new schema::LessT()); op->primitive->value.type = schema::PrimitiveType_Less; op->primitive->value.value = attr.release(); - return RET_OK; } else if (std::strcmp(node_name, "LessEqual") == 0) { MS_LOG(DEBUG) << "parse TfliteLessEqualParser"; std::unique_ptr attr(new schema::LessEqualT()); op->primitive->value.type = schema::PrimitiveType_LessEqual; op->primitive->value.value = attr.release(); - return RET_OK; - } else { - MS_LOG(ERROR) << "wrong op type"; - return RET_ERROR; } + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index 8df29fb87b1..d79da7a58ac 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_MATH_PARSER_H -#define PREDICT_TFLITE_MATH_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -29,11 +30,13 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { public: TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteAddParser : public TfliteDoubleInputOpParser { @@ -96,11 +99,13 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { public: TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteAbsParser : public TfliteSingleInputOpParser { @@ -163,11 +168,13 @@ class TfliteCompareOpParser : public TfliteNodeParser { public: TfliteCompareOpParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteEqualParser : public TfliteCompareOpParser { @@ -203,5 +210,5 @@ class TfliteLessEqualParser : public TfliteCompareOpParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_MATH_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ARITHMETIC_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index f42338bcaf5..b1b4d5fad61 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -19,14 +19,17 @@ #include #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -38,30 +41,32 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "BatchToSpace") == 0) { MS_LOG(DEBUG) << "parse TfliteBatchToSpaceParser"; } else if (std::strcmp(node_name, "BatchToSpaceND") == 0) { MS_LOG(DEBUG) << "parse TfliteBatchToSpaceNDParser"; - // in tflite - // blockShape should be a 1D tensor with dimension [spatial_dims_num] - // crops should be a 2D tensor with dimension [spatial_dims_num, 2] } std::unique_ptr attr(new schema::BatchToSpaceT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->crops)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { MS_LOG(ERROR) << "get batchToSpace -> crops failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h index def11ce9f9a..8e28f3b4cfb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_BATCH_TO_SPACE_PARSER_H -#define LITE_TFLITE_BATCH_TO_SPACE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,8 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { @@ -43,4 +46,4 @@ class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_BATCH_TO_SPACE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index f73bcc9a871..199aff567b2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -18,14 +18,19 @@ #include "tools/converter/parser/tflite/tflite_broadcast_to_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,16 +41,20 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr & return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; std::unique_ptr attr(new schema::BroadcastToT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dst_shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h index 0bbebd449bc..25478346fc2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_BROADCAST_TO_PARSER_H -#define LITE_TFLITE_BROADCAST_TO_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteBroadcastToParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_BROADCAST_TO_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index 78fe45fb250..ee5120ddec8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -14,18 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "tools/converter/parser/tflite/tflite_cast_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteCastParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteCastParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteCastParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,25 +40,28 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteCastParser"; std::unique_ptr attr(new schema::CastT()); - const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; } - attr->srcT = dtype_map[in_tensor->type]; - - const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; + attr->srcT = GetTfliteDataType(in_tensor->type); + const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; } - attr->dstT = dtype_map[out_tensor->type]; + attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index ae1dca284cb..151808dbd5c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_CAST_PARSER_ -#define LITE_TFLITE_CAST_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteCastParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_CAST_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CAST_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc index a930233ceed..86c4d98b89a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -17,14 +17,20 @@ #include "tools/converter/parser/tflite/tflite_concat_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteConcatParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteConcatParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,20 +41,25 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteConcatParser"; std::unique_ptr attr(new schema::ConcatT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } attr->axis = tfliteAttr->axis; - - attr->n = tfliteOp->inputs.size(); + attr->n = tflite_op->inputs.size(); op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h index d2a1acff77c..eac2caf581f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_CONCAT_PARSER_H -#define PREDICT_TFLITE_CONCAT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteConcatParser : public TfliteNodeParser { public: TfliteConcatParser() : TfliteNodeParser("Concat") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONCAT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONCAT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index f075f10e850..8de26527b25 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_conv_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteConvParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,60 +40,61 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteConvParser"; std::unique_ptr attr(new schema::Conv2DT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsConv2DOptions(); - if (tfliteAttr == nullptr) { + const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions(); + if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } attr->group = 1; - attr->strideW = tfliteAttr->stride_w; - attr->strideH = tfliteAttr->stride_h; - attr->dilateH = tfliteAttr->dilation_h_factor; - attr->dilateW = tfliteAttr->dilation_w_factor; - attr->padMode = GetPadMode(tfliteAttr->padding); + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; - attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + attr->hasBias = true; // get the conv op weight tensor - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } auto weight_shape = weight_tensor->shape; - attr->channelIn = weight_shape[KHWC_C]; - attr->channelOut = weight_shape[KHWC_K]; - attr->kernelW = weight_shape[KHWC_W]; - attr->kernelH = weight_shape[KHWC_H]; - - // get the conv op bias tensor - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; // calculate pad params + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; + return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h index d2f523a0c39..abb5d889f75 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_CONV_PARSER_H -#define PREDICT_TFLITE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteConvParser : public TfliteNodeParser { public: TfliteConvParser() : TfliteNodeParser("Conv2D") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h index 92695023307..d510d74a163 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -19,6 +19,7 @@ #include #include +#include #include "tools/converter/converter.h" #include "tools/converter/parser/tflite/tflite_model_parser.h" #include "tools/converter/graphdef_transform.h" diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 61b2e3baf6b..f994f904799 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -17,14 +17,19 @@ #include "tools/converter/parser/tflite/tflite_deconv_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,11 +40,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; std::unique_ptr attr(new schema::DeConv2DT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsTransposeConvOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions(); if (tflite_attr == nullptr) { - MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } @@ -50,26 +54,48 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr &tflit attr->dilateW = 1; attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; + attr->activationType = schema::ActivationType_NO_ACTIVATION; + attr->hasBias = true; // get the conv op weight tensor - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - return RET_ERROR; - } auto weight_shape = weight_tensor->shape; - attr->channelIn = weight_shape[CHWK_K]; - attr->channelOut = weight_shape[CHWK_C]; - attr->kernelW = weight_shape[CHWK_W]; - attr->kernelH = weight_shape[CHWK_H]; + attr->channelIn = weight_shape[3]; + attr->channelOut = weight_shape[0]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; + + // calculate pad params + auto data_index = tflite_op->inputs[2]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; + return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); + } op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h index 46e7e1b8b64..0a26ceb68f3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_DECONV_PARSER_H -#define PREDICT_TFLITE_DECONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteDeConvParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_op_set, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_DECONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DECONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index 04586855f56..b8fceaed9a7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -18,14 +18,19 @@ #include "tools/converter/parser/tflite/tflite_depth_to_space_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,20 +41,23 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; std::unique_ptr attr(new schema::DepthToSpaceT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsDepthToSpaceOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } attr->blockSize = tflite_attr->block_size; - attr->format = schema::Format_NHWC; op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h index 3be9968d8d0..6fac3d3cd11 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H -#define LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_DEPTH_TO_SPACE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTH_TO_SPACE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index d3187aefc01..13354b61d1a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -17,65 +17,22 @@ #include "tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" #include #include +#include #include "tools/common/node_util.h" namespace mindspore { namespace lite { -STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, - const std::unique_ptr &attr, - const std::unique_ptr &weightTensor, - TensorCache *tensor_cache) { - std::unique_ptr convAttr(new schema::Conv2DT); - convAttr->format = attr->format; - convAttr->channelIn = attr->channelIn; - convAttr->channelOut = attr->channelIn * attr->channelMultiplier; - convAttr->kernelH = attr->kernelH; - convAttr->kernelW = attr->kernelW; - convAttr->strideH = attr->strideH; - convAttr->strideW = attr->strideW; - convAttr->padMode = attr->padMode; - convAttr->padUp = attr->padUp; - convAttr->padDown = attr->padDown; - convAttr->padLeft = attr->padLeft; - convAttr->padRight = attr->padRight; - convAttr->dilateH = attr->dilateH; - convAttr->dilateW = attr->dilateW; - convAttr->hasBias = attr->hasBias; - convAttr->activationType = attr->activationType; - - auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name); - if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) { - auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex]; - if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) { - // convert weight format KHWC -> CHWK - auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; - return RET_ERROR; - } - } - - if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) { - // convert weight format KHWC -> CHWK - auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); - if (status != RET_OK) { - MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; - return RET_ERROR; - } - } - } - - op->primitive->value.type = schema::PrimitiveType_Conv2D; - op->primitive->value.value = convAttr.release(); - return RET_OK; -} STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -86,7 +43,6 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr attr(new schema::DepthwiseConv2DT()); const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); if (tflite_attr == nullptr) { @@ -100,15 +56,20 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrpadMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); - // get the conv op weight tensor - auto input_index = tflite_op->inputs[0]; - const auto &input_tenosr = tflite_tensors[input_index]; - if (input_tenosr == nullptr) { - MS_LOG(ERROR) << "the first input is null"; + attr->hasBias = true; + attr->channelMultiplier = tflite_attr->depth_multiplier; + + // get the data tensor + auto data_index = tflite_op->inputs[1]; + const auto &data_tensor = tflite_tensors[data_index]; + if (data_tensor == nullptr) { + MS_LOG(ERROR) << "the data tensor is null"; return RET_NULL_PTR; } - auto input_shape = input_tenosr->shape; + auto data_shape = data_tensor->shape; + attr->channelIn = data_shape[3]; + // get the weight tensor auto weight_index = tflite_op->inputs[1]; const auto &weight_tensor = tflite_tensors[weight_index]; if (weight_tensor == nullptr) { @@ -116,38 +77,33 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptrshape; - attr->channelIn = input_shape[KHWC_C]; - attr->channelMultiplier = tflite_attr->depth_multiplier; - attr->kernelH = weight_shape[KHWC_H]; - attr->kernelW = weight_shape[KHWC_W]; + attr->kernelH = weight_shape[1]; + attr->kernelW = weight_shape[2]; - std::vector weight_tensors{weight_tensor.get()}; - - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; + // calculate pad params + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, + attr->kernelH, attr->kernelW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } - - if (tflite_op->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tflite_op->inputs[2]; - const auto &bias_tensor = tflite_tensors[bias_index]; - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } - - if (attr->channelMultiplier > 1) { - if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { - MS_LOG(ERROR) << "Parse Group DepthwiseConv failed"; - return RET_ERROR; - } } else { - op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - op->primitive->value.value = attr.release(); + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); } + + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h index 2e0b1a0d021..6e4022f4fb1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H -#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEPTHWISE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,20 +29,16 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { public: TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; - - private: - STATUS ParseGroupDepthwiseConv(schema::CNodeT *op, - const std::unique_ptr &attr, - const std::unique_ptr &weightTensor, - TensorCache *tensor_cache); + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 74bde414e52..ab0ac6d906c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -16,15 +16,20 @@ #include "tools/converter/parser/tflite/tflite_dequantize_parser.h" #include #include +#include #include "tools/common/node_util.h" namespace mindspore { namespace lite { -STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,32 +40,30 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; std::unique_ptr attr(new schema::CastT); // get the dequantize input tensor - const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; + MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - attr->srcT = dtype_map[in_tensor->type]; - - const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; + attr->srcT = GetTfliteDataType(in_tensor->type); + const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { - MS_LOG(ERROR) << "tensor is null"; + MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; } - attr->dstT = dtype_map[out_tensor->type]; - std::vector weight_tensors{in_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } + attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.type = schema::PrimitiveType_Fp16Cast; op->primitive->value.value = attr.release(); - return 0; + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h index 3d6e521d7da..276bd3e748e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -13,11 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H -#define LITE_TFLITE_DEQUANTIZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -27,13 +28,15 @@ class TfliteDequantizeParser : public TfliteNodeParser { public: TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_DEQUANTIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 624a0e51933..0d3c91528f9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -17,16 +17,17 @@ #include "tools/converter/parser/tflite/tflite_expand_dims_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -40,7 +41,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr &t MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; std::unique_ptr attr(new schema::ExpandDimsT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsExpandDimsOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h index aa867bc3159..cdbda4b5b8c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_EXPAND_DIMS_PARSER_H -#define PREDICT_TFLITE_EXPAND_DIMS_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteExpandDimsParser : public TfliteNodeParser { public: TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_EXPAND_DIMS_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_EXPAND_DIMS_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc deleted file mode 100644 index fa0e90ae11a..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/converter/parser/tflite/tflite_fakequant_parser.h" -#include -#include - -namespace mindspore { -namespace lite { -STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; - std::unique_ptr attr(new schema::FullConnectionT()); - - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; - return RET_NULL_PTR; - } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } - attr->axis = 1; - - op->primitive->value.type = schema::PrimitiveType_FullConnection; - op->primitive->value.value = attr.release(); - return RET_OK; -} - -TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser()); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h deleted file mode 100644 index 101c6cfec13..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H -#define LITE_TFLITE_FAKEQUANT_PARSER_H - -#include -#include -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" - -namespace mindspore { -namespace lite { -class TfliteFakeQuantParser : public TfliteNodeParser { - public: - TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {} - - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // LITE_TFLITE_FAKEQUANT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index a7b3be7f66d..70405dbf1e1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -14,19 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_fill_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_fill_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteFillParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteFillParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteFillParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,18 +40,22 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteFillParser"; std::unique_ptr attr(new schema::FillT()); - if (tfliteOp->inputs.size() > 1) { - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->dims)) { - MS_LOG(ERROR) << "get Fill -> dims failed"; + if (tflite_op->inputs.size() > 1) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) { + MS_LOG(ERROR) << "get fill -> dims failed"; return RET_ERROR; } } op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h index 5d8fdee06d8..7bfcd5df999 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_FILL_PARSER_H -#define PREDICT_TFLITE_FILL_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteFillParser : public TfliteNodeParser { public: TfliteFillParser() : TfliteNodeParser("Fill") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_FILL_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FILL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 0084ad8c7ee..ba02dc35ef9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -14,17 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_fullyconnected_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_fullyconnected_parser.h" +#include +#include namespace mindspore { namespace lite { -STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,51 +39,43 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr node_name_str; + Split(op->name, &node_name_str, "-"); + const char *node_name = node_name_str.data()->c_str(); + if (std::strcmp(node_name, "FullyConnected") == 0) { + MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; + } else if (std::strcmp(node_name, "FakeQuant") == 0) { + MS_LOG(DEBUG) << "parse TfliteFakeQuantParser"; + } std::unique_ptr attr(new schema::FullConnectionT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsFullyConnectedOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; - return RET_NULL_PTR; - } - attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); - - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; return RET_NULL_PTR; } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - if (tfliteOp->inputs.size() == 3) { - attr->hasBias = true; - auto bias_index = tfliteOp->inputs[2]; - const auto &bias_tensor = tfliteTensors[bias_index]; - if (bias_tensor == nullptr) { - MS_LOG(ERROR) << "bias_tensor is null"; - return RET_NULL_PTR; - } - std::vector bias_tensors{bias_tensor.get()}; - if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse bias failed"; - return RET_ERROR; - } - } + attr->hasBias = true; attr->axis = 1; attr->useAxis = false; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); op->primitive->value.type = schema::PrimitiveType_FullConnection; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); +TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser());; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h index f41ab2e3c03..21fd8186ad6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_ADD_PARSER_H -#define PREDICT_TFLITE_ADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,21 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { public: TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; +}; + +class TfliteFakeQuantParser : public TfliteFullyConnectedParser { + public: + TfliteFakeQuantParser() : TfliteFullyConnectedParser() {} }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_FULLY_CONNECTED_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index a954baef012..e8218bb4318 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_gather_nd_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,37 +40,18 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr &tfl return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; std::unique_ptr attr(new schema::GatherNdT()); - - if (tfliteOp->inputs.size() != 2) { - MS_LOG(ERROR) << "The input size of gather_nd should be 2"; - return RET_ERROR; - } - - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; - return RET_NULL_PTR; - } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; - return RET_NULL_PTR; - } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - attr->batchDims = 0; op->primitive->value.type = schema::PrimitiveType_GatherNd; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h index c79d8aa753b..4d9c3e525c2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_GATHER_ND_PARSER_H -#define PREDICT_TFLITE_GATHER_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteGatherNdParser : public TfliteNodeParser { public: TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_GATHER_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index e49f95d5116..cf388e94d89 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -17,16 +17,20 @@ #include "tools/converter/parser/tflite/tflite_gather_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteGatherParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteGatherParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,39 +41,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteGatherParser"; std::unique_ptr attr(new schema::GatherT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsGatherOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } attr->axis = tflite_attr->axis; - attr->batchDims = 0; - auto y_index = tfliteOp->inputs[1]; - const auto &y_tensor = tfliteTensors[y_index]; - if (y_tensor == nullptr) { - MS_LOG(ERROR) << "the second input is null"; - return RET_NULL_PTR; - } - auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); - if (y_data == nullptr) { - MS_LOG(ERROR) << "the data of the second input is null"; - return RET_NULL_PTR; - } - if (!y_data->data.empty()) { - std::vector y_tensors{y_tensor.get()}; - if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse the second tensor failed"; - return RET_ERROR; - } - } - op->primitive->value.type = schema::PrimitiveType_Gather; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h index 5dd842414ae..08ea38976ca 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_GATHER_PARSER_H -#define PREDICT_TFLITE_GATHER_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteGatherParser : public TfliteNodeParser { public: TfliteGatherParser() : TfliteNodeParser("Gather") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_GATHER_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_GATHER_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc deleted file mode 100644 index e7ff131f73f..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "tools/converter/parser/tflite/tflite_hard_swish_parser.h" - -namespace mindspore { -namespace lite { -STATUS TfliteHardSwishParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is null"; - return RET_NULL_PTR; - } - op->primitive = std::make_unique(); - if (op->primitive == nullptr) { - MS_LOG(ERROR) << "op->primitive is null"; - return RET_NULL_PTR; - } - - MS_LOG(INFO) << "parse TfliteHardSwishParser"; - std::unique_ptr attr(new schema::ActivationT()); - - attr->type = schema::ActivationType_HSWISH; - - op->primitive->value.type = schema::PrimitiveType_Activation; - op->primitive->value.value = attr.release(); - return RET_OK; -} - -TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h deleted file mode 100644 index 00de1d24583..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hard_swish_parser.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PREDICT_TFLITE_HARD_SWISH_PARSER_H -#define PREDICT_TFLITE_HARD_SWISH_PARSER_H - -#include -#include -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" - -namespace mindspore { -namespace lite { -class TfliteHardSwishParser : public TfliteNodeParser { - public: - TfliteHardSwishParser() : TfliteNodeParser("HardSwish") {} - - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // PREDICT_TFLITE_HARD_SWISH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc index 078523ff400..5f4a4d18353 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_logical_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_logical_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +40,7 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfli } std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "LogicalAnd") == 0) { MS_LOG(DEBUG) << "parse TfliteLogicalAndParser"; @@ -45,21 +48,24 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_LogicalAnd; op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "LogicalNot") == 0) { - MS_LOG(INFO) << "parse TfliteLogicalNotParser"; + MS_LOG(DEBUG) << "parse TfliteLogicalNotParser"; std::unique_ptr attr(new schema::LogicalNotT()); op->primitive->value.type = schema::PrimitiveType_LogicalNot; op->primitive->value.value = attr.release(); } else if (std::strcmp(node_name, "LogicalOr") == 0) { - MS_LOG(INFO) << "parse TfliteLogicalOrParser"; + MS_LOG(DEBUG) << "parse TfliteLogicalOrParser"; std::unique_ptr attr(new schema::LogicalOrT()); op->primitive->value.type = schema::PrimitiveType_LogicalOr; op->primitive->value.value = attr.release(); - } else { - MS_LOG(ERROR) << "wrong logical type"; - return RET_ERROR; } -return RET_OK; + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + return RET_OK; } TfliteNodeRegister g_TfliteLogicalAndParser("LogicalAnd", new TfliteLogicalAndParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h index 3608f1f12d5..a56de847e6c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_LOGICAL_AND_PARSER_H -#define PREDICT_TFLITE_LOGICAL_AND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_AND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_AND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -29,12 +30,13 @@ class TfliteLogicalParser : public TfliteNodeParser { public: TfliteLogicalParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteLogicalAndParser : public TfliteLogicalParser { @@ -54,4 +56,4 @@ class TfliteLogicalOrParser : public TfliteLogicalParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_LOGICAL_AND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_AND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index 7d37779377f..8986729d72e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -17,16 +17,20 @@ #include "tools/converter/parser/tflite/tflite_lrn_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteLRNParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteLRNParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,10 +41,9 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteLRNParser"; std::unique_ptr attr(new schema::LocalResponseNormalizationT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsLocalResponseNormalizationOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -52,6 +55,11 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr &tfliteOp op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h index b7eae4f9781..a64179c64bb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_LRN_PARSER_H -#define PREDICT_TFLITE_ADD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteLRNParser : public TfliteNodeParser { public: TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_LRN_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LRN_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 7ea6fecfd3d..b9990d0ace7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -17,6 +17,7 @@ #include "tools/converter/parser/tflite/tflite_model_parser.h" #include #include +#include #include "tools/common/graph_util.h" #include "tools/common/storage.h" #include "flatbuffers/flatbuffers.h" @@ -24,43 +25,45 @@ namespace mindspore { namespace lite { -TfliteModelParser::TfliteModelParser() {} +TfliteModelParser::TfliteModelParser() = default; -TfliteModelParser::~TfliteModelParser() {} +TfliteModelParser::~TfliteModelParser() = default; -std::unique_ptr TfliteModelParser::ReadTfliteModelFromFlat(const char *model_path) { +std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *model_path) { size_t size; auto buf = ReadFile(model_path, &size); if (buf == nullptr) { MS_LOG(ERROR) << "the file buffer is nullptr"; - return nullptr; } flatbuffers::Verifier verify((const uint8_t *)buf, size); if (!tflite::VerifyModelBuffer(verify)) { MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; - return nullptr; } return tflite::UnPackModel(buf); } -std::string TfliteModelParser::GetTfliteNodeType(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model) { - auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; - auto msOpType = GetMSOpType(tflite_op_type); - return msOpType; -} - -STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graphDef) { - std::vector tensors = tensor_cache.GetCachedTensor(); - for (auto iter : tensors) { - std::unique_ptr temp(iter); - temp->format = schema::Format_NHWC; - sub_graphDef->allTensors.emplace_back(move(temp)); +STATUS TfliteModelParser::CopyConstTensorData(const std::vector> &tflite_model_buffer, + const tflite::TensorT *tflite_tensor, + schema::TensorT *tensor) { + auto count = 1; + std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); + auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); + auto buffer_idx = tflite_tensor->buffer; + if (!tflite_model_buffer[buffer_idx]->data.empty()) { + tensor->data.resize(data_size); + if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size)) { + MS_LOG(ERROR) << "memcpy tensor data failed"; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "src tensor data is empty"; + return RET_ERROR; } return RET_OK; } -void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, - schema::TensorT *tensor) { + +void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr &tflite_tensor, + schema::TensorT *tensor) { std::unique_ptr quant_param(new QuantParamT()); if (!tflite_tensor->quantization->scale.empty()) { quant_param->scale = tflite_tensor->quantization->scale[0]; @@ -87,221 +90,228 @@ void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptrquantParams.emplace_back(std::move(quant_param)); } -STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, - schema::CNodeT *op, TensorCache *tensor_cache) { - MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size()); - for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)]; - if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && - tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { - continue; - } - auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i)); - if (inTensor == nullptr) { - MS_LOG(ERROR) << "Parse tflite quant params inTensor is null"; - return RET_NULL_PTR; - } - SetMsTensorFromTflite(tflite_tensor, inTensor); - } - for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)]; - if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && - tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { - continue; - } - auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i)); - if (outTensor == nullptr) { - MS_LOG(ERROR) << "Parse tflite quant params outTensor is null"; - return RET_NULL_PTR; - } - SetMsTensorFromTflite(tflite_tensor, outTensor); - } - return RET_OK; -} - -STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensorCache) { - for (const auto &index : tflite_op->outputs) { - const auto &tflite_tensor = tflite_subgraph->tensors[index]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << index << " is null"; - return RET_ERROR; - } - std::unique_ptr tensor(new schema::TensorT()); - tensor->dataType = GetTfliteDataType(tflite_tensor->type); - // change dataType to int8 to fit ms-lite op - if (tensor->dataType == TypeId::kNumberTypeUInt8) { - tensor->dataType = TypeId::kNumberTypeInt8; - } - tensor->dims = tflite_tensor->shape; - tensor->nodeType = schema::NodeType_Parameter; - auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); - op->outputIndex.emplace_back(opOutputIndex); - } - return RET_OK; -} - -STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache) { - auto op_type = GetTfliteNodeType(tflite_op, tflite_model); - std::vector op_inputs(tflite_op->inputs); - if (op_type == "DeConv2D") { - reverse(op_inputs.begin(), op_inputs.end()); - } - - for (const auto &tflite_index : op_inputs) { - const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null"; - return RET_ERROR; - } - auto tensor_name = tflite_tensor->name; - unsigned int index = tensor_cache->FindTensor(tensor_name); - if (index != -1) { - op->inputIndex.push_back(index); - } - } - - return RET_OK; -} - -STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache, - const QuantType &quantType) { - auto i = 0; +STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const QuantType &quant_type, + schema::MetaGraphT* sub_graph) { + int idx = 0; for (const auto &tflite_op : tflite_subgraph->operators) { - auto opType = GetTfliteNodeType(tflite_op, tflite_model); + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto op_type = GetMSOpType(tflite_op_type); std::unique_ptr op(new schema::CNodeT); - op->name = opType + "-" + std::to_string(i++); - op->quantType = quantType; + op->name = op_type + "-" + std::to_string(idx++); + op->quantType = quant_type; MS_LOG(INFO) << "parse op: " << op->name.c_str(); - auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); + // parse tflite op + auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); if (node_parser == nullptr) { - MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str(); - continue; - // return RET_NULL_PTR; + MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); + return RET_NULL_PTR; } - - auto status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, - tflite_model->operator_codes, op.get(), tensorCache, false); - if (status != RET_OK) { - MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed"; + if (node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, + &tensorsFormat, &tensorsIdMap) != RET_OK) { + MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; return RET_ERROR; } - status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed"; - return RET_ERROR; - } - - status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed"; - return RET_ERROR; - } - - if (quantType != schema::QuantType_QUANT_NONE) { - status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; - return RET_ERROR; - } - } - - subGraph->nodes.emplace_back(std::move(op)); - opMap[subGraph->nodes.back()->name] = subGraph->nodes.back().get(); - tfliteOpMap[tflite_op.get()] = subGraph->nodes.back().get(); + // add + sub_graph->nodes.emplace_back(op.release()); + opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); + tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); } return RET_OK; } -void TfliteModelParser::SetInputTensor(const std::unique_ptr &tflite_subgraph, - TensorCache *tensor_cache) { - for (const auto &index : tflite_subgraph->inputs) { - const auto &tflite_tensor = tflite_subgraph->tensors[index]; +STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, + const std::vector> &tflite_model_buffer, + schema::MetaGraphT* sub_graph) { + for (int i = 0; i < tensorsId.size(); i++) { + auto idx = tensorsId[i]; + if (idx < 0) { + idx += tflite_subgraph->tensors.size(); + } + const auto &tflite_tensor = tflite_subgraph->tensors[idx]; std::unique_ptr tensor(new schema::TensorT()); - tensor->format = schema::Format_NHWC; - tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8 - ? GetTfliteDataType(tflite_tensor->type) - : TypeId::kNumberTypeInt8; - tensor->nodeType = schema::NodeType_Parameter; + + tensor->format = tensorsFormat[i]; + tensor->dataType = GetTfliteDataType(tflite_tensor->type); tensor->dims = tflite_tensor->shape; - tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); - } -} -void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model, - const mindspore::lite::TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef) { - auto graphInputs = tensorCache.GetGraphInputs(); - subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); - - for (auto outputIndex : tflite_subgraph->outputs) { - int i = 0; - bool found = false; - for (const auto &tfliteOp : tflite_subgraph->operators) { - int j = 0; - auto opType = GetTfliteNodeType(tfliteOp, tflite_model); - std::string opName = opType + "-" + std::to_string(i++); - for (auto opOutputIndex : tfliteOp->outputs) { - if (outputIndex == opOutputIndex) { - subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]); - found = true; - break; - } - j++; - } - if (found) { + // if graph input tensor + bool isInput = false; + auto tflite_inputs = tflite_subgraph->inputs; + for (int tflite_input : tflite_inputs) { + if (idx == tflite_input) { + isInput = true; break; } } + + // add data for const tensor + auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer); + auto isConst = (!tensor_buffer->data.empty()); + if (isConst) { + CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); + } + + // set tensor attr + if (isInput || isConst) { + tensor->nodeType = schema::NodeType_ValueNode; + } else { + tensor->nodeType = schema::NodeType_Parameter; + } + + // quant param + if (!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) { + SetTensorQuantParam(tflite_tensor, tensor.get()); + } + + tensors.push_back(tensor.release()); } + + for (auto iter : tensors) { + std::unique_ptr temp(iter); + sub_graph->allTensors.emplace_back(move(temp)); + } + return RET_OK; } -MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { - if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { - MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; - return nullptr; - } +STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT* sub_graph) { + int id; - std::unique_ptr tflite_model(new tflite::ModelT()); - tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); - if (tflite_model == nullptr) { - MS_LOG(ERROR) << "read tflite model failed"; - return nullptr; + // graph input + std::vector graph_inputs; + for (int i = 0; i < tflite_subgraph->inputs.size(); i++) { + const int idx = tflite_subgraph->inputs[i]; + if (idx < 0) { + id = idx + tflite_subgraph->tensors.size(); + } else { + id = idx; + } + auto iter = tensorsIdMap.find(id); + if (iter != tensorsIdMap.end()) { + graph_inputs.push_back(iter->second); + } } + sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end()); + + // graph output + std::vector graph_outputs; + for (int i = 0; i < tflite_subgraph->outputs.size(); i++) { + const int idx = tflite_subgraph->outputs[i]; + if (idx < 0) { + id = idx + tflite_subgraph->tensors.size(); + } else { + id = idx; + } + auto iter = tensorsIdMap.find(id); + if (iter != tensorsIdMap.end()) { + graph_outputs.push_back(iter->second); + } + } + sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end()); + return RET_OK; +} + +STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) { + for (auto &op : sub_graph->nodes) { + if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + auto attr = op->primitive->value.AsDepthwiseConv2D(); + if (attr->channelMultiplier > 1) { + // update attr + std::unique_ptr conv_attr(new schema::Conv2DT); + conv_attr->group = 0; + conv_attr->format = attr->format; + conv_attr->channelIn = attr->channelIn; + conv_attr->channelOut = attr->channelIn * attr->channelMultiplier; + conv_attr->kernelH = attr->kernelH; + conv_attr->kernelW = attr->kernelW; + conv_attr->strideH = attr->strideH; + conv_attr->strideW = attr->strideW; + conv_attr->padMode = attr->padMode; + conv_attr->padUp = attr->padUp; + conv_attr->padDown = attr->padDown; + conv_attr->padLeft = attr->padLeft; + conv_attr->padRight = attr->padRight; + conv_attr->dilateH = attr->dilateH; + conv_attr->dilateW = attr->dilateW; + conv_attr->hasBias = attr->hasBias; + conv_attr->activationType = attr->activationType; + + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = conv_attr.release(); + + // update weight + auto weight_id = op->inputIndex[1]; + auto &weight_tensor = sub_graph->allTensors.at(weight_id); + if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(weight_tensor.get(), kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + } + } + } + return RET_OK; +} + + +MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, + const std::string &weight_file, + const QuantType &quant_type) { + std::unique_ptr sub_graph(new schema::MetaGraphT); + sub_graph->name = "MS_model converted by TF-Lite"; + + // load graph + // std::unique_ptr tflite_model(new tflite::ModelT()); + std::unique_ptr tflite_model = ReadTfliteModel(model_file.c_str()); + if (tflite_model->subgraphs.size() != 1) { MS_LOG(ERROR) << "read tflite model subgraphs failed"; return nullptr; } const auto &tflite_subgraph = tflite_model->subgraphs[0]; - // set dst subGraph input/output tensor - TensorCache tensorCache; - SetInputTensor(tflite_subgraph, &tensorCache); - - // set dst subGraph op attr and tensor_cache. - std::unique_ptr subGraph(new schema::MetaGraphT); - subGraph->name = "MS_model converted by TF-Lite"; - auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType); - if (status != RET_OK) { - MS_LOG(ERROR) << "ParseOp failed."; + // convert op + if (ConvertOp(tflite_model, tflite_subgraph, quant_type, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "parse op failed."; return nullptr; } - SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get()); - SetAllTensors(tensorCache, subGraph.get()); - return subGraph.release(); + // convert tensor + if (ConvertTensor(tflite_subgraph, tflite_model->buffers, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "convert tensor failed"; + return nullptr; + } + + // set graph input/output + if (GetGraphInfo(tflite_subgraph, sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "convert tensors failed"; + return nullptr; + } + + // update for depthwiseConv + if (UpdateOp(sub_graph.get()) != RET_OK) { + MS_LOG(ERROR) << "update depthwise conv failed"; + return nullptr; + } + + return sub_graph.release(); } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index ed861bb866f..4c8eba0728c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "securec/include/securec.h" #include "tools/converter/model_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -38,44 +39,41 @@ class TfliteModelParser : public ModelParser { public: TfliteModelParser(); - virtual ~TfliteModelParser(); + ~TfliteModelParser() override; - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + MetaGraphT *Parse(const std::string &model_file, + const std::string &weight_file, const QuantType &quantType = QuantType_QUANT_NONE) override; private: - std::unique_ptr ReadTfliteModelFromFlat(const char *buf); + std::unique_ptr ReadTfliteModel(const char *model_path); - void SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, schema::TensorT *tensor); + STATUS CopyConstTensorData(const std::vector> &tflite_model_buffer, + const tflite::TensorT *tflite_tensor, + schema::TensorT *tensor); - void SetInputTensor(const std::unique_ptr &tflite_subgraph, TensorCache *tensor_cache); + void SetTensorQuantParam(const std::unique_ptr &tflite_tensor, + schema::TensorT *tensor); - void SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model, - const mindspore::lite::TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef); + STATUS ConvertOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + const QuantType &quant_type, + schema::MetaGraphT* sub_graph); - STATUS ParseOp(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph, - TensorCache *tensor_cache, const QuantType &quantType); + STATUS ConvertTensor(const std::unique_ptr &tflite_subgraph, + const std::vector> &tflite_model_buffer, + schema::MetaGraphT* sub_graph); - STATUS ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache); + STATUS GetGraphInfo(const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT* sub_graph); - std::string GetTfliteNodeType(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model); + STATUS UpdateOp(schema::MetaGraphT* sub_graph); - STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); - - STATUS SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensorCache); - - STATUS SetOpInputIdx(const std::unique_ptr &tflite_model, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, schema::CNodeT *op, - TensorCache *tensor_cache); + private: + std::vector tensorsId; + std::vector tensorsFormat; + std::map tensorsIdMap; + std::vector tensors; std::map opMap; std::map tfliteOpMap; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc deleted file mode 100644 index cd24097896f..00000000000 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "securec/include/securec.h" -#include "tools/converter/parser/tflite/tflite_node_parser.h" -#include "tools/converter/parser/tflite/tflite_util.h" - -namespace mindspore { -namespace lite { -STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, - const tflite::TensorT *tflite_tensor, - schema::TensorT *tensor) { - auto count = 1; - std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); - auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); - auto buffer_idx = tflite_tensor->buffer; - if (!tfliteModelBuffer[buffer_idx]->data.empty()) { - tensor->data.resize(data_size); - if (memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size)) { - MS_LOG(ERROR) << "memcpy tensor data failed"; - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "src tensor data is empty"; - return RET_ERROR; - } - return RET_OK; -} - -STATUS TfliteNodeParser::ParseTensor(const std::vector &ts, - const std::vector> &tfliteModelBuffer, - mindspore::lite::TensorCache *tensor_cache, - int node_type, - bool isWeight) { - for (const auto &t : ts) { - auto idx = tensor_cache->FindTensor(t->name); - if (idx < 0) { - std::unique_ptr tensor(new schema::TensorT); - tensor->dataType = GetTfliteDataType(t->type); - tensor->dims = t->shape; - - if (isWeight) { - tensor->format = schema::Format_KHWC; - } else { - tensor->format = schema::Format_NHWC; - } - - if (t->buffer > 0) { - CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); - } - - MS_LOG(DEBUG) << "add tensor name: " << t->name.c_str(); - tensor_cache->AddTensor(t->name, tensor.release(), node_type); - } - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index de5f0d1b287..c2422e8d900 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -21,54 +21,84 @@ #include #include #include +#include #include "utils/log_adapter.h" #include "schema/inner/model_generated.h" -#include "tools/converter/parser/tflite/tflite_util.h" #include "tools/converter/parser/tflite/schema_generated.h" #include "tools/common/tensor_util.h" #include "ir/dtype/type_id.h" #include "include/errorcode.h" +#include "tools/converter/parser/tflite/tflite_util.h" namespace mindspore { namespace lite { class TfliteNodeParser { public: - explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} + explicit TfliteNodeParser(const std::string &node_name) : name(node_name) {} virtual ~TfliteNodeParser() = default; - virtual STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, + virtual STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) = 0; + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) = 0; - STATUS ParseTensor(const std::vector &ts, - const std::vector> &tfliteModelBuffer, - mindspore::lite::TensorCache *tensor_cache, - int node_type, - bool isWeight); + void AddOpInput(schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map, + int idx, int new_idx, int total, schema::Format format) { + auto iter = tensors_id_map->find(idx); + if (iter != tensors_id_map->end()) { + op->inputIndex.emplace_back(iter->second); + } else { + if (idx < 0) { + idx += total; + } + tensors_id->emplace_back(idx); + tensors_format->emplace_back(format); + tensors_id_map->insert(std::make_pair(idx, new_idx)); + op->inputIndex.emplace_back(new_idx); + } + } - STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, - const tflite::TensorT *tflite_tensor, - schema::TensorT *tensor); + void AddOpOutput(schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map, + int idx, int new_idx, int total, schema::Format format) { + auto iter = tensors_id_map->find(idx); + if (iter != tensors_id_map->end()) { + op->outputIndex.emplace_back(iter->second); + } else { + if (idx < 0) { + idx += total; + } + tensors_id->emplace_back(idx); + tensors_format->emplace_back(format); + tensors_id_map->insert(std::make_pair(idx, new_idx)); + op->outputIndex.emplace_back(new_idx); + } + } template - STATUS GetTfliteData(const int32_t tensor_index, const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, + STATUS GetTfliteData(const int32_t tensor_index, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, std::vector &attr_data) { int32_t count = 1; - std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), + std::for_each(tflite_tensors[tensor_index]->shape.begin(), tflite_tensors[tensor_index]->shape.end(), [&](int32_t sha) { count *= sha; }); - auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; + auto &buf_data = tflite_model_buffer[tflite_tensors[tensor_index]->buffer]; if (buf_data == nullptr) { MS_LOG(ERROR) << "buf_data is null"; return RET_NULL_PTR; } auto data_ptr = buf_data->data.data(); - switch (tfliteTensors[tensor_index]->type) { + switch (tflite_tensors[tensor_index]->type) { case tflite::TensorType_UINT8: { for (int i = 0; i < count; i++) { uint8_t data = *(static_cast(static_cast(data_ptr))); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index 0da5759332e..4f8b7261611 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_one_hot_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_one_hot_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteOneHotParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,17 +40,15 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteOneHotParser"; std::unique_ptr attr(new schema::OneHotT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsOneHotOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } - auto axis = tflite_attr->axis; - const auto &tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -58,6 +61,13 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_OneHot; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h index f21659714a1..8a23110957f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_ONE_HOT_PARSER_H -#define PREDICT_TFLITE_ONE_HOT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteOneHotParser : public TfliteNodeParser { public: TfliteOneHotParser() : TfliteNodeParser("OneHot") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ONE_HOT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ONE_HOT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index cfd17005478..8f4535b7459 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_pad_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_pad_parser.h" +#include namespace mindspore { namespace lite { -STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TflitePadParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TflitePadParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,9 +40,9 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TflitePadParser"; std::unique_ptr attr(new schema::PadT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsPadOptions(); + + const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -45,13 +50,18 @@ STATUS TflitePadParser::Parse(const std::unique_ptr &tfliteOp attr->paddingMode = schema::PaddingMode_CONSTANT; attr->constantValue = 0.0f; - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { MS_LOG(ERROR) << "get pad -> paddings failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Pad; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h index e2f0c29c7b8..44f657ad4e2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_PAD_PARSER_H -#define PREDICT_TFLITE_PAD_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,13 +29,15 @@ class TflitePadParser : public TfliteNodeParser { public: TflitePadParser() : TfliteNodeParser("Pad") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_PAD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_PAD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index f2c51fcd515..1d7db44adb5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_pooling_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_pooling_parser.h" +#include namespace mindspore { namespace lite { STATUS TflitePoolingParser::Parse(const std::unique_ptr &tflite_op, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -39,7 +42,7 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli std::unique_ptr attr(new schema::PoolingT()); std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); if (std::strcmp(node_name, "MeanPooling") == 0) { MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser"; @@ -47,9 +50,6 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli } else if (std::strcmp(node_name, "MaxPooling") == 0) { MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser"; attr->poolingMode = schema::PoolMode_MAX_POOLING; - } else { - MS_LOG(ERROR) << "wrong pooling type"; - return RET_ERROR; } const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); @@ -64,41 +64,31 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr &tfli attr->padMode = GetPadMode(tflite_attr->padding); attr->format = schema::Format_NHWC; - // by default attr->global = false; attr->roundMode = schema::RoundMode_FLOOR; // calculate pad params - if (attr->padMode == schema::PadMode_VALID || attr->padMode == schema::PadMode_NOTSET) { - attr->padUp = 0; - attr->padDown = 0; - attr->padLeft = 0; - attr->padRight = 0; - } else if (attr->padMode == schema::PadMode_SAME) { - auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tfliteTensors[data_index]; - if (data_tensor == nullptr) { - MS_LOG(ERROR) << "the first input is null"; - return RET_NULL_PTR; - } - - auto shape = data_tensor->shape; - int H_input = shape.at(1); - int W_input = shape.at(2); - - int H_output = ceil(H_input / attr->strideH); - int pad_needed_H = (H_output - 1) * attr->strideH + attr->windowH - H_input; - attr->padUp = floor(pad_needed_H / 2.0); - attr->padDown = pad_needed_H - attr->padUp; - - int W_output = ceil(W_input / attr->strideW); - int pad_needed_W = (W_output - 1) * attr->strideW + attr->windowW - W_input; - attr->padLeft = floor(pad_needed_W / 2.0); - attr->padRight = pad_needed_W - attr->padLeft; + auto data_index = tflite_op->inputs[0]; + const auto &data_tensor = tflite_tensors[data_index]; + std::vector params; + if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, + attr->strideW, attr->windowH, attr->windowW, ¶ms) != RET_OK) { + MS_LOG(ERROR) << "get padding params failed"; + return RET_ERROR; + } else { + attr->padUp = params.at(0); + attr->padDown = params.at(1); + attr->padLeft = params.at(2); + attr->padRight = params.at(3); } op->primitive->value.type = schema::PrimitiveType_Pooling; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h index 1129df01cd2..fe8d5fa8041 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_MEAN_POOLING_PARSER_H -#define PREDICT_TFLITE_MEAN_POOLING_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MEAN_POOLING_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MEAN_POOLING_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,12 +29,13 @@ class TflitePoolingParser : public TfliteNodeParser { public: TflitePoolingParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteMeanPoolingParser : public TflitePoolingParser { @@ -48,5 +50,5 @@ class TfliteMaxPoolingParser : public TflitePoolingParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index 38216a911e8..519b5c243ba 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_range_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteRangeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteRangeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,13 +40,20 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteRangeParser"; std::unique_ptr attr(new schema::RangeT()); attr->dType = 0; +// attr->start +// attr->limit +// attr->delta op->primitive->value.type = schema::PrimitiveType_Range; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h index 27015901511..204b5c8e73b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RANGE_PARSER_H -#define PREDICT_TFLITE_RANGE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteRangeParser : public TfliteNodeParser { public: TfliteRangeParser() : TfliteNodeParser("Range") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RANGE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANGE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc index d66b5278dc6..d4fcac5371c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_rank_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteRankParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteRankParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteRankParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,11 +40,15 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr &tfliteO return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteRankParser"; std::unique_ptr attr(new schema::RankT()); op->primitive->value.type = schema::PrimitiveType_Rank; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h index 11257b6f2bb..9afd1fcc22b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RANK_PARSER_H -#define PREDICT_TFLITE_RANK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteRankParser : public TfliteNodeParser { public: TfliteRankParser() : TfliteNodeParser("Rank") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RANK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RANK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index 4472430e48a..9cfbceed32a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -14,18 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_reduce_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_reduce_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteReduceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,8 +41,9 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit } std::unique_ptr attr(new schema::ReduceT()); + // auto tflite_tensors = tflite_subgraph->tensors; - const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; @@ -46,8 +51,9 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit attr->keepDims = tflite_attr->keep_dims; std::vector node_name_str; - Split(op->name.data(), &node_name_str, "-"); + Split(op->name, &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); + if (std::strcmp(node_name, "ReduceMax") == 0) { MS_LOG(DEBUG) << "parse TfliteReduceMaxParser"; attr->mode = schema::ReduceMode_ReduceMax; @@ -67,18 +73,20 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr &tflit // attr->mode; MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; return RET_NOT_FIND_OP; - } else { - MS_LOG(ERROR) << "wrong reduce type"; - return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axes)) { - MS_LOG(ERROR) << "get reduce_prod -> axes failed"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axes)) { + MS_LOG(ERROR) << "get reduce -> axes failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Reduce; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h index 628243881c4..7960143948f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_REDUCE_MAX_PARSER_H -#define PREDICT_TFLITE_REDUCE_MAX_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,12 +29,13 @@ class TfliteReduceParser : public TfliteNodeParser { public: TfliteReduceParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteReduceMaxParser : public TfliteReduceParser { @@ -69,4 +71,4 @@ class TfliteReduceAnyParser : public TfliteReduceParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_REDUCE_MAX_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index c67b113c376..7bdc2d83184 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_reshape_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_reshape_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReshapeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,23 +40,22 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteReshapeParser"; std::unique_ptr attr(new schema::ReshapeT()); - const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions(); if (tfliteAttr == nullptr) { - if (tfliteOp->inputs.size() < 2) { - MS_LOG(ERROR) << "expected two input tensors, but got: " << tfliteOp->inputs.size(); + if (tflite_op->inputs.size() < 2) { + MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size(); return RET_ERROR; } - auto shape_tensor_index = tfliteOp->inputs[1]; - const auto & shape_tensor = tfliteTensors[shape_tensor_index]; + auto shape_tensor_index = tflite_op->inputs[1]; + const auto & shape_tensor = tflite_tensors[shape_tensor_index]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->shape)) { - MS_LOG(ERROR) << "get reshape->shape error"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) { + MS_LOG(ERROR) << "get reshape -> shape failed"; return RET_ERROR; } } else { @@ -64,6 +68,13 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Reshape; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h index a122d9512fe..6ab5fa1db6b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RESHAPE_PARSER_H -#define PREDICT_TFLITE_RESHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteReshapeParser : public TfliteNodeParser { public: TfliteReshapeParser() : TfliteNodeParser("Reshape") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_ADD_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ADD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 4dfc27c32e7..4bb37bb93c4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -14,18 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_resize_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_resize_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteResizeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -41,10 +44,9 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit std::vector node_name_str; Split(op->name.data(), &node_name_str, "-"); const char *node_name = node_name_str.data()->c_str(); - if (std::strcmp(node_name, "ResizeBilinear") == 0) { MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser"; - const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeBilinearOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -53,7 +55,7 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit attr->method = schema::ResizeMethod_BILINEAR; } else if (std::strcmp(node_name, "NearestNeighbor") == 0) { MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser"; - const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeNearestNeighborOptions(); + const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions(); if (tfliteAttr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -68,14 +70,14 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit attr->format = schema::Format_NHWC; attr->preserveAspectRatio = false; - auto tfliteResizeTensorIndex = tfliteOp->inputs[1]; - const auto & shape_tensor = tfliteTensors[tfliteResizeTensorIndex]; + auto tfliteResizeTensorIndex = tflite_op->inputs[1]; + const auto & shape_tensor = tflite_tensors[tfliteResizeTensorIndex]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } auto resizeTensorBufferIndex = shape_tensor->buffer; - const auto & buff = tfliteModelBuffer.at(resizeTensorBufferIndex); + const auto & buff = tflite_model_buffer.at(resizeTensorBufferIndex); if (buff == nullptr) { MS_LOG(ERROR) << "buff_data is null"; return RET_NULL_PTR; @@ -88,6 +90,11 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_Resize; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h index 779a1cf0cd9..14245ba3e9e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_RESIZE_PARSER_H -#define PREDICT_TFLITE_RESIZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,11 +29,13 @@ class TfliteResizeParser : public TfliteNodeParser { public: TfliteResizeParser() : TfliteNodeParser("node_name") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; class TfliteResizeBilinearParser : public TfliteResizeParser { @@ -48,5 +51,5 @@ class TfliteResizeNearestNeighborParser : public TfliteResizeParser { } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_RESIZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index a1ddd7f13cf..7e449a3588a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -17,16 +17,19 @@ #include "tools/converter/parser/tflite/tflite_reverse_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { +STATUS TfliteReverseParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReverseParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,15 +40,20 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteReverseParser"; std::unique_ptr attr(new schema::ReverseT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axis)) { + MS_LOG(ERROR) << "get reverse -> axis failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Reverse; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h index 3965ab1eceb..d9fa0ce2df3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_REVERSE_PARSER_H -#define PREDICT_TFLITE_REVERSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,15 +29,16 @@ class TfliteReverseParser : public TfliteNodeParser { public: TfliteReverseParser() : TfliteNodeParser("reverse") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_REVERSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index 867ca6bc77c..3d996a4feff 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -15,17 +15,23 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_reverse_sequence_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_reverse_sequence_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,7 +42,6 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr attr(new schema::ReverseSequenceT()); const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions(); @@ -54,6 +59,11 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_ReverseSequence; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h index 20cac753e11..927247fe86b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H -#define LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_REVERSE_SEQUENCE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REVERSE_SEQUENCE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index f1dac84959c..e86e7c7a758 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -14,18 +14,23 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_scatter_nd_parser.h" #include #include #include -#include "tools/converter/parser/tflite/tflite_scatter_nd_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,33 +41,26 @@ STATUS TfliteScatterNdParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteScatterNdParser"; std::unique_ptr attr(new schema::ScatterNDT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsScatterNdOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } - /* - MS_LOG(DEBUG) << "op->inputIndex"; - for (auto &i : op->inputIndex) { - MS_LOG(DEBUG) << i; - } - */ - // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 - // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; - std::swap(op->inputIndex[0], op->inputIndex[2]); - std::swap(op->inputIndex[1], op->inputIndex[2]); - /* - MS_LOG(DEBUG) << "op->inputIndex after resort"; - for (auto &i : op->inputIndex) { - MS_LOG(DEBUG) << i; - } - */ - op->primitive->value.type = schema::PrimitiveType_ScatterND; op->primitive->value.value = attr.release(); + + // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 + // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[2], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h index 38232968855..6baeed21bed 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SCATTER_ND_PARSER_H -#define PREDICT_TFLITE_SCATTER_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteScatterNdParser : public TfliteNodeParser { public: TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SCATTER_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SCATTER_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc index 95163afde18..eae6cdb3527 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_shape_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_shape_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteShapeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteShapeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,11 +40,15 @@ STATUS TfliteShapeParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteShapeParser"; std::unique_ptr attr(new schema::ShapeT()); op->primitive->value.type = schema::PrimitiveType_Shape; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h index b0f0fee85c0..ab34d3c901d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SHAPE_PARSER_H -#define PREDICT_TFLITE_SHAPE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteShapeParser : public TfliteNodeParser { public: TfliteShapeParser() : TfliteNodeParser("Shape") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SHAPE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SHAPE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index a18e624380d..14d31067235 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -17,15 +17,20 @@ #include "tools/converter/parser/tflite/tflite_slice_parser.h" #include #include +#include namespace mindspore { namespace lite { -STATUS TfliteSliceParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSliceParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,20 +41,26 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSliceParser"; std::unique_ptr attr(new schema::SliceT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->begin)) { + attr->format = schema::Format_NHWC; + + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { MS_LOG(ERROR) << "get slice -> begin failed"; return RET_ERROR; } - if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->size)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->size)) { MS_LOG(ERROR) << "get slice -> size failed"; return RET_ERROR; } op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h index 70c1b96da7e..0a84cda6422 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SLICE_PARSER_H -#define PREDICT_TFLITE_SLICE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteSliceParser : public TfliteNodeParser { public: TfliteSliceParser() : TfliteNodeParser("Slice") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SLICE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 87277a5872f..5c25d4837a5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_softmax_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_softmax_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,19 +40,17 @@ STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; std::unique_ptr attr(new schema::SoftMaxT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSoftmaxOptions(); - if (tflite_attr == nullptr) { - MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; - return RET_NULL_PTR; - } - attr->axis = -1; op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h index 685898c429d..73576c110c5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_CONV_PARSER_H -#define PREDICT_TFLITE_CONV_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteSoftmaxParser : public TfliteNodeParser { public: TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_CONV_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index 9e3fd7db393..bfd9e8248f4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptr attr(new schema::SpaceToBatchNDT()); if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { @@ -51,6 +54,11 @@ STATUS TfliteSpaceToBatchNDParser::Parse(const std::unique_ptrprimitive->value.type = schema::PrimitiveType_SpaceToBatchND; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h index 287f492bc6b..e8b5b69a11e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H -#define LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPACE_TO_BATCH_ND_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_BATCH_ND_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 19d33c1bfa0..b9dca6927d0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_space_to_depth_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_space_to_depth_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; std::unique_ptr attr(new schema::SpaceToDepthT()); const auto &tflite_attr = tflite_op->builtin_options.AsSpaceToDepthOptions(); @@ -46,11 +49,15 @@ STATUS TfliteSpaceToDepthParser::Parse(const std::unique_ptr return RET_NULL_PTR; } attr->blockSize = tflite_attr->block_size; - attr->format = schema::Format_NHWC; op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h index 3adf534253a..be2cc7a16c0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H -#define LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPACE_TO_DEPTH_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPACE_TO_DEPTH_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index 4a98f90d8f4..8859377b96f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; std::unique_ptr attr(new schema::SparseToDenseT()); attr->validateIndices = false; @@ -57,6 +60,11 @@ STATUS TfliteSparseToDenseParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_SparseToDense; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h index ad4a6e02ce7..f4c496f5bff 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H -#define LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_SPARSE_TO_DENSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPARSE_TO_DENSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index 9e90bbfa030..cf3695eeb10 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -14,19 +14,23 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_split_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_split_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSplitParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, +STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSplitParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,33 +41,32 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSplitParser"; std::unique_ptr attr(new schema::SplitT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSplitOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSplitOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } auto num_splits = tflite_attr->num_splits; - const auto &shape_tensor = tfliteTensors[tfliteOp->inputs[1]]; + const auto &shape_tensor = tflite_tensors[tflite_op->inputs[1]]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } const auto tensor_shape = shape_tensor->shape; - const auto &axis_tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &axis_tensor = tflite_tensors[tflite_op->inputs[0]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tfliteModelBuffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } if (axis >= tensor_shape.size()) { - MS_LOG(ERROR) << "axis value too large"; + MS_LOG(ERROR) << "axis value is too large"; return RET_ERROR; } attr->splitDim = axis; @@ -79,6 +82,13 @@ STATUS TfliteSplitParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h index 39210e5086e..997fbb01e23 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SPLIT_PARSER_H -#define PREDICT_TFLITE_SPLIT_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSplitParser : public TfliteNodeParser { public: TfliteSplitParser() : TfliteNodeParser("Split") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SPLIT_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index 0fdf4b3a3f4..f134e085cd4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -14,17 +14,20 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_split_v_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_split_v_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,44 +38,51 @@ STATUS TfliteSplitVParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSplitVParser"; + MS_LOG(DEBUG) << "parse TfliteSplitVParser"; std::unique_ptr attr(new schema::SplitT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSplitVOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSplitVOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name << " attr failed"; return RET_NULL_PTR; } attr->numberSplit = tflite_attr->num_splits; - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->sizeSplits)) { - MS_LOG(ERROR) << "get splite_v -> sizeSplits failed"; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->sizeSplits)) { + MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; return RET_ERROR; } - const auto &tensor = tfliteTensors[tfliteOp->inputs[0]]; + const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor_shape is null"; return RET_NULL_PTR; } auto tensor_shape = tensor->shape; - const auto &axis_tensor = tfliteTensors[tfliteOp->inputs[2]]; + const auto &axis_tensor = tflite_tensors[tflite_op->inputs[2]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; } - auto axis = *(reinterpret_cast(tfliteModelBuffer[axis_tensor->buffer]->data.data())); + auto axis = *(reinterpret_cast(tflite_model_buffer[axis_tensor->buffer]->data.data())); if (axis < 0) { axis += tensor_shape.size(); } if (axis >= tensor_shape.size()) { - MS_LOG(ERROR) << "axis value too large"; + MS_LOG(ERROR) << "axis value is too large"; return RET_ERROR; } attr->splitDim = axis; op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h index c3eefcdec20..125ddbc30d6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SPLIT_V_PARSER_H -#define PREDICT_TFLITE_SPLIT_V_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSplitVParser : public TfliteNodeParser { public: TfliteSplitVParser() : TfliteNodeParser("SplitV") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SPLIT_V_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SPLIT_V_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index bd6ee5f6414..a5720e4a089 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_squeeze_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_squeeze_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteSqueezeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,10 +40,9 @@ STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(INFO) << "parse TfliteSqueezeParser"; std::unique_ptr attr(new schema::SqueezeT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsSqueezeOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsSqueezeOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; @@ -47,6 +51,11 @@ STATUS TfliteSqueezeParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Squeeze; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h index 77387738568..e1b9c9436e0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_SQUEEZE_PARSER_H -#define PREDICT_TFLITE_SQUEEZE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,15 @@ class TfliteSqueezeParser : public TfliteNodeParser { public: TfliteSqueezeParser() : TfliteNodeParser("Squeeze") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_SQUEEZE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SQUEEZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc index b6b7da18b12..63d4f0fbfb3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -14,17 +14,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_stack_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_stack_parser.h" +#include namespace mindspore { namespace lite { -STATUS TfliteStackParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteStackParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,21 +40,26 @@ STATUS TfliteStackParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteStackParser"; std::unique_ptr attr(new schema::StackT()); - const auto &tflite_attr = tfliteOp->builtin_options.AsPackOptions(); + const auto &tflite_attr = tflite_op->builtin_options.AsPackOptions(); if (tflite_attr == nullptr) { MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; return RET_NULL_PTR; } - attr->axis = tflite_attr->axis; attr->n = tflite_attr->values_count; - attr->isScale.assign(tfliteTensors[tfliteOp->inputs[0]]->shape.begin(), - tfliteTensors[tfliteOp->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_Stack; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h index db85b07828f..3e6774239aa 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_STACK_PARSER_H -#define PREDICT_TFLITE_STACK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteStackParser : public TfliteNodeParser { public: TfliteStackParser() : TfliteNodeParser("Stack") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_STACK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STACK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index 04f33beb167..123e458665b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -14,18 +14,20 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_strided_slice_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_strided_slice_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -43,7 +45,6 @@ STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } - attr->beginMask = tflite_attr->begin_mask; attr->endMask = tflite_attr->end_mask; attr->ellipsisMask = tflite_attr->ellipsis_mask; @@ -67,6 +68,11 @@ STATUS TfliteStridedSliceParser::Parse(const std::unique_ptr op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h index 4a2b1814dbc..9e4db2461f2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -15,11 +15,12 @@ */ -#ifndef LITE_TFLITE_STRIDED_SLICE_PARSER_H -#define LITE_TFLITE_STRIDED_SLICE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -32,12 +33,13 @@ class TfliteStridedSliceParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_STRIDED_SLICE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_STRIDED_SLICE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index c73477ef3f9..6b618829b05 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_tile_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_tile_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTileParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTileParser"; std::unique_ptr attr(new schema::TileT()); if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->multiples)) { @@ -47,6 +50,11 @@ STATUS TfliteTileParser::Parse(const std::unique_ptr &tflite_ op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h index 48ba12053a4..3534901911a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_TILE_PARSER_H -#define LITE_TFLITE_TILE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteTileParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_TILE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TILE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index cd55c852b48..a89a641d131 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_topk_v2_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_topk_v2_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,9 +41,9 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; std::unique_ptr attr(new schema::TopKV2T()); + attr->sorted = true; if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) { MS_LOG(ERROR) << "get topKV2 -> k failed"; return RET_ERROR; @@ -47,6 +51,11 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit op->primitive->value.type = schema::PrimitiveType_TopKV2; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h index 3bf9c1dabf5..6ed92506c14 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_TOPK_V2_PARSER_H -#define LITE_TFLITE_TOPK_V2_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteTopKV2Parser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_TOPK_V2_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TOPK_V2_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index 5f3d90b8898..759c3d8fc52 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -14,17 +14,21 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_transpose_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_transpose_parser.h" namespace mindspore { namespace lite { -STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { +STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteTransposeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -35,28 +39,23 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteTransposeParser"; std::unique_ptr attr(new schema::TransposeT()); - if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->perm)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->perm)) { MS_LOG(ERROR) << "get transpose -> perm failed"; return RET_ERROR; } - auto weight_index = tfliteOp->inputs[1]; - const auto &weight_tensor = tfliteTensors[weight_index]; - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "weight_tensor is null"; - return RET_ERROR; - } - std::vector weight_tensors{weight_tensor.get()}; - if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { - MS_LOG(ERROR) << "parse weight failed"; - return RET_ERROR; - } - + attr->conjugate = false; op->primitive->value.type = schema::PrimitiveType_Transpose; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[1], tensors_id->size(), tflite_tensors.size(), schema::Format_KHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h index f92eed0c544..4fc4062713c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef PREDICT_TFLITE_TRANSPOSE_PARSER_H -#define PREDICT_TFLITE_TRANSPOSE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -28,14 +29,16 @@ class TfliteTransposeParser : public TfliteNodeParser { public: TfliteTransposeParser() : TfliteNodeParser("Transpose") {} - STATUS Parse(const std::unique_ptr &tfliteOp, - const std::vector> &tfliteTensors, - const std::vector> &tfliteModelBuffer, - const std::vector> &tfliteOpSet, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantizedModel) override; + STATUS Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tflite_model_buffer, + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // PREDICT_TFLITE_TRANSPOSE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_TRANSPOSE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc index 6860c01bab6..aa052c250d2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_unique_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_unique_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteUniqueParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteUniqueParser"; std::unique_ptr attr(new schema::UniqueT()); const auto &tflite_attr = tflite_op->builtin_options.AsUniqueOptions(); @@ -45,11 +48,17 @@ STATUS TfliteUniqueParser::Parse(const std::unique_ptr &tflit MS_LOG(ERROR) << "get op: %s attr failed", op->name.c_str(); return RET_NULL_PTR; } - - attr->outType = dtype_map[tflite_attr->idx_out_type]; + attr->outType = GetTfliteDataType(tflite_attr->idx_out_type); op->primitive->value.type = schema::PrimitiveType_Unique; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h index 331aa5e48f8..2fadd9aa2f8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_UNIQUE_PARSER_H -#define LITE_TFLITE_UNIQUE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H #include #include @@ -32,24 +32,12 @@ class TfliteUniqueParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; - - private: - std::map dtype_map = { - {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, - {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, - {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, - {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, - {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, - {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, - {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, - {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, - }; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_UNIQUE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNIQUE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index a2ea8b5a6be..b78f4fe7991 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -15,17 +15,23 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_unstack_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_unstack_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, - schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) { + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "paser TfliteUnstackParser"; + + // set attr if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -36,7 +42,6 @@ STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tfli return RET_NULL_PTR; } - MS_LOG(DEBUG) << "paser TfliteUnstackParser"; std::unique_ptr attr(new schema::UnstackT()); const auto &tflite_attr = tflite_op->builtin_options.AsUnpackOptions(); @@ -49,6 +54,13 @@ STATUS TfliteUnstackParser::Parse(const std::unique_ptr &tfli op->primitive->value.type = schema::PrimitiveType_Unstack; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + for (int i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h index 82729e7f38e..28fed8b7148 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_UNSTACK_PARSER_H -#define LITE_TFLITE_UNSTACK_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteUnstackParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_UNSTACK_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_UNSTACK_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index d8358a8e5e0..3a4f1b8a9ed 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -15,25 +15,15 @@ */ #include "tools/converter/parser/tflite/tflite_util.h" -#include #include #include +#include +#include #include "utils/log_adapter.h" #include "include/errorcode.h" namespace mindspore { namespace lite { -std::map tfMsActivationFunctionMap{ - {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, - {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, - {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, - {tflite::ActivationFunctionType_TANH, schema::ActivationType_TANH}, -}; - -schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { - return tfMsActivationFunctionMap.at(tfliteAFType); -} - std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_CONV_2D, "Conv2D"}, {tflite::BuiltinOperator_DEPTHWISE_CONV_2D, "DepthwiseConv2D"}, @@ -129,6 +119,29 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_UNPACK, "Unstack"}, }; +std::map tfMsActivationFunctionMap{ + {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, + {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, + {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, + {tflite::ActivationFunctionType_TANH, schema::ActivationType_TANH}, +}; + +std::map type_map = { + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, + {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, + {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, + {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, +}; + +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { + return tfMsActivationFunctionMap.at(tfliteAFType); +} + std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { auto iter = tfMsOpTypeMap.find(tfliteOpType); if (iter == tfMsOpTypeMap.end()) { @@ -138,16 +151,6 @@ std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { return iter->second; } -std::map type_map = { - {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, - {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, - {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, - {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, - {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, - {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, -}; - TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type) { auto iter = type_map.find(tflite_data_type); if (iter == type_map.end()) { @@ -183,12 +186,48 @@ size_t GetDataTypeSize(const TypeId &data_type) { case TypeId::kNumberTypeInt64: return sizeof(int64_t); default: - MS_LOG(ERROR) << data_type; - MS_LOG(ERROR) << "Unsupported datatype"; + MS_LOG(ERROR) << data_type << " is Unsupported datatype"; return RET_ERROR; } } +STATUS getPaddingParam(const std::unique_ptr &tensor, + schema::PadMode pad_mode, + int strideH, int strideW, + int windowH, int windowW, + std::vector *params) { + if (tensor == nullptr) { + MS_LOG(ERROR) << "the input tensor is null"; + return RET_ERROR; + } + + int padUp = 0; + int padDown = 0; + int padLeft = 0; + int padRight = 0; + if (pad_mode == schema::PadMode_SAME) { + auto shape = tensor->shape; + int H_input = shape.at(1); + int W_input = shape.at(2); + + int H_output = ceil(H_input * 1.0 / strideH); + int pad_needed_H = (H_output - 1) * strideH + windowH - H_input; + padUp = floor(pad_needed_H / 2.0); + padDown = pad_needed_H - padUp; + + int W_output = ceil(W_input * 1.0 / strideW); + int pad_needed_W = (W_output - 1) * strideW + windowW - W_input; + padLeft = floor(pad_needed_W / 2.0); + padRight = pad_needed_W - padLeft; + } + + params->emplace_back(padUp); + params->emplace_back(padDown); + params->emplace_back(padLeft); + params->emplace_back(padRight); + return RET_OK; +} + void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr) { std::string ::size_type p1 = 0, p2 = src_str.find(chr); while (std::string::npos != p2) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h index ee78e96a36a..9dc0bba97d9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -19,11 +19,14 @@ #include #include +#include +#include #include "utils/log_adapter.h" #include "schema/inner/model_generated.h" #include "tools/converter/parser/tflite/schema_generated.h" #include "schema/inner/ops_generated.h" #include "ir/dtype/type_id.h" +#include "include/errorcode.h" namespace mindspore { namespace lite { @@ -37,7 +40,15 @@ std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType); TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); -void Split(const std::string &src_str, std::vector *dst_str, const std::string &chr); +STATUS getPaddingParam(const std::unique_ptr &tensor, + schema::PadMode pad_mode, + int strideH, int strideW, + int windowH, int windowW, + std::vector *params); + +void Split(const std::string &src_str, + std::vector *dst_str, + const std::string &chr); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc index d13b35a6d54..88d027e41a2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_where_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_where_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteWhereParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,7 +41,6 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteWhereParser"; std::unique_ptr attr(new schema::WhereT()); if (GetTfliteData(tflite_op->inputs[0], tflite_tensors, tflite_model_buffer, attr->condition)) { @@ -47,6 +50,13 @@ STATUS TfliteWhereParser::Parse(const std::unique_ptr &tflite op->primitive->value.type = schema::PrimitiveType_Where; op->primitive->value.value = attr.release(); + + for (int i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + } + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h index 3a707d0a8b2..583b8dffe6e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_WHERE_PARSER_H -#define LITE_TFLITE_WHERE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteWhereParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_WHERE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index fb88ba965b4..d5c1750d598 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -15,18 +15,22 @@ * limitations under the License. */ +#include "tools/converter/parser/tflite/tflite_zeros_like_parser.h" #include #include -#include "tools/converter/parser/tflite/tflite_zeros_like_parser.h" +#include namespace mindspore { namespace lite { STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, bool quantized_model) { + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) { + MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; + if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -37,11 +41,15 @@ STATUS TfliteZerosLikeParser::Parse(const std::unique_ptr &tf return RET_NULL_PTR; } - MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; std::unique_ptr attr(new schema::ZerosLikeT()); op->primitive->value.type = schema::PrimitiveType_ZerosLike; op->primitive->value.value = attr.release(); + + AddOpInput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); + AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, + tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h index 0c656137d58..a8cec073fd6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -14,11 +14,12 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_ZEROS_LIKE_PARSER_H -#define LITE_TFLITE_ZEROS_LIKE_PARSER_H +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H #include #include +#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -31,11 +32,12 @@ class TfliteZerosLikeParser : public TfliteNodeParser { STATUS Parse(const std::unique_ptr &tflite_op, const std::vector> &tflite_tensors, const std::vector> &tflite_model_buffer, - const std::vector> &tflite_opset, schema::CNodeT *op, - TensorCache *tensor_cache, - bool quantized_model) override; + schema::CNodeT *op, + std::vector *tensors_id, + std::vector *tensors_format, + std::map *tensors_id_map) override; }; } // namespace lite } // namespace mindspore -#endif // LITE_TFLITE_ZEROS_LIKE_PARSER_H +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ZEROS_LIKE_PARSER_H