From 174d308cfad0f0b1243b53194ca7cd245e69d465 Mon Sep 17 00:00:00 2001 From: hangq Date: Wed, 29 Jul 2020 02:16:33 +0800 Subject: [PATCH] add mindspore lite --- .gitignore | 16 +- .gitmodules | 8 + build.sh | 376 +- mindspore/core/ir/CMakeLists.txt | 4 - mindspore/core/ir/lite/tensor.cc | 88 - mindspore/core/ir/lite/tensor.h | 61 - mindspore/core/ir/meta_tensor.cc | 2 - mindspore/core/ir/meta_tensor_extends.cc | 2 + mindspore/core/ir/param_value.h | 2 +- mindspore/core/utils/log_adapter.cc | 28 + mindspore/core/utils/log_adapter.h | 4 + mindspore/lite/CMakeLists.txt | 119 + mindspore/lite/build.sh | 272 ++ mindspore/lite/cmake-build-cloud/Lite.cbp | 4235 +++++++++++++++++ .../googletest/CTestTestfile.cmake | 7 + .../googletest/googlemock/CTestTestfile.cmake | 7 + .../googletest/googlemock/gmock.cbp | 517 ++ .../googlemock/gtest/CTestTestfile.cmake | 6 + .../gtest/generated/GTestConfig.cmake | 33 + .../gtest/generated/GTestConfigVersion.cmake | 37 + .../googlemock/gtest/generated/gmock.pc | 9 + .../googlemock/gtest/generated/gmock_main.pc | 9 + .../googlemock/gtest/generated/gtest.pc | 9 + .../googlemock/gtest/generated/gtest_main.pc | 10 + .../googletest/googlemock/gtest/gtest.cbp | 327 ++ .../googletest/googletest-distribution.cbp | 517 ++ .../runtime/kernel/arm/opclib/optimize.cbp | 112 + mindspore/lite/cmake-build-minnie/Lite.cbp | 4235 +++++++++++++++++ .../googletest/CTestTestfile.cmake | 7 + .../googletest/googlemock/CTestTestfile.cmake | 7 + .../googletest/googlemock/gmock.cbp | 517 ++ .../googlemock/gtest/CTestTestfile.cmake | 6 + .../gtest/generated/GTestConfig.cmake | 33 + .../gtest/generated/GTestConfigVersion.cmake | 37 + .../googlemock/gtest/generated/gmock.pc | 9 + .../googlemock/gtest/generated/gmock_main.pc | 9 + .../googlemock/gtest/generated/gtest.pc | 9 + .../googlemock/gtest/generated/gtest_main.pc | 10 + .../googletest/googlemock/gtest/gtest.cbp | 327 ++ .../googletest/googletest-distribution.cbp | 517 ++ .../runtime/kernel/arm/opclib/optimize.cbp | 112 + mindspore/lite/include/context.h | 76 + mindspore/lite/include/errorcode.h | 55 + mindspore/lite/include/lite_session.h | 52 + mindspore/lite/include/model.h | 57 + mindspore/lite/include/ms_tensor.h | 70 + mindspore/lite/schema/model.fbs | 208 + mindspore/lite/schema/ops.fbs | 719 +++ mindspore/lite/src/CMakeLists.txt | 83 + .../src/common/anf_exporter/CMakeLists.txt | 7 + .../src/common/anf_exporter/anf_exporter.cc | 263 + .../src/common/anf_exporter/anf_exporter.h | 46 + .../anf_populater/anf_activation_populater.cc | 42 + .../anf_populater/anf_activation_populater.h | 30 + .../anf_populater/anf_batchnorm_populater.cc | 37 + .../anf_populater/anf_batchnorm_populater.h | 29 + .../anf_populater/anf_biasadd_populater.cc | 37 + .../anf_populater/anf_biasadd_populater.h | 29 + .../anf_populater/anf_conv_populater.cc | 121 + .../anf_populater/anf_conv_populater.h | 32 + .../anf_populater/anf_flatten_populater.cc | 35 + .../anf_populater/anf_flatten_populater.h | 29 + .../anf_populater/anf_matmul_populater.cc | 38 + .../anf_populater/anf_matmul_populater.h | 29 + .../anf_populater/anf_mul_populater.cc | 35 + .../anf_populater/anf_mul_populater.h | 29 + .../anf_populater/anf_node_populater.cc | 19 + .../anf_populater/anf_node_populater.h | 33 + .../anf_node_populater_registry.cc | 48 + .../anf_node_populater_registry.h | 43 + .../anf_populater/anf_pool_populater.cc | 68 + .../anf_populater/anf_pool_populater.h | 29 + .../anf_populater/anf_reducemean_populater.cc | 40 + .../anf_populater/anf_reducemean_populater.h | 29 + .../anf_populater/anf_tensoradd_populater.cc | 34 + .../anf_populater/anf_tensoradd_populater.h | 29 + .../anf_tuple_getitem_populater.cc | 34 + .../anf_tuple_getitem_populater.h | 29 + .../src/common/anf_importer/anf_importer.cc | 184 + .../src/common/anf_importer/anf_importer.h | 54 + .../anf_importer/import_from_meta_graph.cc | 122 + .../anf_importer/import_from_meta_graph.h | 47 + .../anf_importer/import_from_meta_graphT.cc | 123 + .../anf_importer/import_from_meta_graphT.h | 49 + .../anf_importer/import_from_protobuf.cc | 717 +++ .../anf_importer/import_from_protobuf.h | 92 + mindspore/lite/src/common/common.h | 64 + mindspore/lite/src/common/file_utils.cc | 168 + mindspore/lite/src/common/file_utils.h | 58 + mindspore/lite/src/common/graph_util.cc | 77 + mindspore/lite/src/common/graph_util.h | 250 + .../lite/src/common/graph_utils_extends.cc | 151 + mindspore/lite/src/common/op_utils.h | 32 + mindspore/lite/src/common/utils.cc | 262 + mindspore/lite/src/common/utils.h | 193 + mindspore/lite/src/context.cc | 31 + mindspore/lite/src/executor.cc | 124 + mindspore/lite/src/executor.h | 48 + mindspore/lite/src/gllo/common/node_pass.cc | 68 + mindspore/lite/src/gllo/common/node_pass.h | 36 + mindspore/lite/src/gllo/common/optimizer.cc | 117 + mindspore/lite/src/gllo/common/optimizer.h | 90 + mindspore/lite/src/gllo/common/pass.h | 41 + .../lite/src/gllo/common/pass_manager.cc | 89 + mindspore/lite/src/gllo/common/pass_manager.h | 61 + .../lite/src/gllo/common/pattern_engine.cc | 365 ++ .../lite/src/gllo/common/pattern_engine.h | 203 + mindspore/lite/src/gllo/common/utils.cc | 207 + mindspore/lite/src/gllo/common/utils.h | 48 + mindspore/lite/src/gllo/common/visit.cc | 165 + mindspore/lite/src/gllo/common/visit.h | 59 + .../src/gllo/fusion/conv_biasadd_fusion.cc | 73 + .../src/gllo/fusion/conv_biasadd_fusion.h | 34 + mindspore/lite/src/ir/meta_tensor_extends.cc | 28 + mindspore/lite/src/ir/primitive_t_value.cc | 17 + mindspore/lite/src/ir/primitive_t_value.h | 76 + mindspore/lite/src/ir/primitive_value.cc | 19 + mindspore/lite/src/ir/primitive_value.h | 47 + mindspore/lite/src/ir/tensor.cc | 323 ++ mindspore/lite/src/ir/tensor.h | 224 + mindspore/lite/src/kernel_factory.cc | 61 + mindspore/lite/src/kernel_factory.h | 41 + mindspore/lite/src/kernel_registry.cc | 59 + mindspore/lite/src/kernel_registry.h | 61 + mindspore/lite/src/lite_kernel.cc | 144 + mindspore/lite/src/lite_kernel.h | 182 + mindspore/lite/src/lite_session.cc | 283 ++ mindspore/lite/src/lite_session.h | 78 + mindspore/lite/src/model.cc | 53 + mindspore/lite/src/model_impl.cc | 187 + mindspore/lite/src/model_impl.h | 54 + mindspore/lite/src/ops/CMakeLists.txt | 3 + mindspore/lite/src/ops/addn.cc | 44 + mindspore/lite/src/ops/argmax.cc | 48 + mindspore/lite/src/ops/argmin.cc | 47 + mindspore/lite/src/ops/arithmetic.cc | 102 + mindspore/lite/src/ops/arithmetic_self.cc | 34 + mindspore/lite/src/ops/batch_to_space.cc | 93 + mindspore/lite/src/ops/broadcast_to.cc | 66 + mindspore/lite/src/ops/cast.cc | 52 + mindspore/lite/src/ops/concat.cc | 76 + mindspore/lite/src/ops/conv.cc | 86 + .../lite/src/ops/convolution_depthwise.cc | 80 + mindspore/lite/src/ops/crop.cc | 38 + mindspore/lite/src/ops/deconvolution.cc | 73 + .../lite/src/ops/deconvolution_depthwise.cc | 69 + mindspore/lite/src/ops/depth_to_space.cc | 62 + mindspore/lite/src/ops/expand_dims.cc | 51 + mindspore/lite/src/ops/fill.cc | 48 + mindspore/lite/src/ops/flatten.cc | 49 + mindspore/lite/src/ops/fullconnection.cc | 62 + mindspore/lite/src/ops/gather.cc | 77 + mindspore/lite/src/ops/gather_nd.cc | 65 + mindspore/lite/src/ops/matmul.cc | 63 + mindspore/lite/src/ops/nchw2nhwc.cc | 42 + mindspore/lite/src/ops/nhwc2nchw.cc | 42 + mindspore/lite/src/ops/one_hot.cc | 72 + mindspore/lite/src/ops/ops.cc | 144 + mindspore/lite/src/ops/ops.h | 666 +++ mindspore/lite/src/ops/pad.cc | 63 + mindspore/lite/src/ops/pooling.cc | 82 + mindspore/lite/src/ops/range.cc | 40 + mindspore/lite/src/ops/rank.cc | 35 + mindspore/lite/src/ops/reduce.cc | 78 + mindspore/lite/src/ops/reshape.cc | 120 + mindspore/lite/src/ops/resize.cc | 51 + mindspore/lite/src/ops/reverse_sequence.cc | 35 + mindspore/lite/src/ops/scatter_nd.cc | 63 + mindspore/lite/src/ops/shape.cc | 64 + mindspore/lite/src/ops/slice.cc | 68 + mindspore/lite/src/ops/softmax.cc | 34 + mindspore/lite/src/ops/split.cc | 62 + mindspore/lite/src/ops/squeeze.cc | 75 + mindspore/lite/src/ops/stack.cc | 67 + mindspore/lite/src/ops/strided_slice.cc | 162 + mindspore/lite/src/ops/tile.cc | 45 + mindspore/lite/src/ops/topk.cc | 48 + mindspore/lite/src/ops/transpose.cc | 53 + mindspore/lite/src/ops/unique.cc | 42 + mindspore/lite/src/ops/unsqueeze.cc | 73 + mindspore/lite/src/ops/unstack.cc | 48 + mindspore/lite/src/ops/where.cc | 79 + mindspore/lite/src/ops/zeroslike.cc | 39 + mindspore/lite/src/param_value_lite.h | 79 + mindspore/lite/src/populate_parameter.cc | 1036 ++++ mindspore/lite/src/populate_parameter.h | 28 + mindspore/lite/src/runtime/allocator.cc | 123 + mindspore/lite/src/runtime/allocator.h | 79 + .../src/runtime/kernel/arm/CMakeLists.txt | 34 + .../runtime/kernel/arm/base/concat_base.cc | 107 + .../src/runtime/kernel/arm/base/concat_base.h | 53 + .../kernel/arm/base/convolution_base.cc | 502 ++ .../kernel/arm/base/convolution_base.h | 71 + .../kernel/arm/base/fullconnection_base.cc | 79 + .../kernel/arm/base/fullconnection_base.h | 49 + .../kernel/arm/base/layout_transform.cc | 71 + .../kernel/arm/base/layout_transform.h | 40 + .../src/runtime/kernel/arm/base/matrix.cc | 83 + .../lite/src/runtime/kernel/arm/base/matrix.h | 98 + .../runtime/kernel/arm/base/pooling_base.cc | 150 + .../runtime/kernel/arm/base/pooling_base.h | 52 + .../runtime/kernel/arm/base/reshape_base.cc | 106 + .../runtime/kernel/arm/base/reshape_base.h | 48 + .../kernel/arm/fp16/convolution_3x3_fp16.cc | 252 + .../kernel/arm/fp16/convolution_3x3_fp16.h | 79 + .../kernel/arm/fp16/convolution_fp16.cc | 221 + .../kernel/arm/fp16/convolution_fp16.h | 75 + .../src/runtime/kernel/arm/fp32/activation.cc | 106 + .../src/runtime/kernel/arm/fp32/activation.h | 47 + .../lite/src/runtime/kernel/arm/fp32/addn.cc | 74 + .../lite/src/runtime/kernel/arm/fp32/addn.h | 39 + .../src/runtime/kernel/arm/fp32/argminmax.cc | 100 + .../src/runtime/kernel/arm/fp32/argminmax.h | 41 + .../src/runtime/kernel/arm/fp32/arithmetic.cc | 168 + .../src/runtime/kernel/arm/fp32/arithmetic.h | 145 + .../kernel/arm/fp32/arithmetic_self.cc | 116 + .../runtime/kernel/arm/fp32/arithmetic_self.h | 108 + .../runtime/kernel/arm/fp32/batch_to_space.cc | 89 + .../runtime/kernel/arm/fp32/batch_to_space.h | 42 + .../lite/src/runtime/kernel/arm/fp32/bias.cc | 94 + .../lite/src/runtime/kernel/arm/fp32/bias.h | 43 + .../runtime/kernel/arm/fp32/broadcast_to.cc | 78 + .../runtime/kernel/arm/fp32/broadcast_to.h | 42 + .../lite/src/runtime/kernel/arm/fp32/cast.cc | 114 + .../lite/src/runtime/kernel/arm/fp32/cast.h | 49 + .../src/runtime/kernel/arm/fp32/concat.cc | 71 + .../lite/src/runtime/kernel/arm/fp32/concat.h | 45 + .../runtime/kernel/arm/fp32/convolution.cc | 201 + .../src/runtime/kernel/arm/fp32/convolution.h | 59 + .../kernel/arm/fp32/convolution_1x1.cc | 231 + .../runtime/kernel/arm/fp32/convolution_1x1.h | 69 + .../kernel/arm/fp32/convolution_3x3.cc | 237 + .../runtime/kernel/arm/fp32/convolution_3x3.h | 70 + .../kernel/arm/fp32/convolution_depthwise.cc | 149 + .../kernel/arm/fp32/convolution_depthwise.h | 59 + .../kernel/arm/fp32/convolution_winograd.cc | 338 ++ .../kernel/arm/fp32/convolution_winograd.h | 74 + .../lite/src/runtime/kernel/arm/fp32/crop.cc | 93 + .../lite/src/runtime/kernel/arm/fp32/crop.h | 48 + .../runtime/kernel/arm/fp32/deconvolution.cc | 227 + .../runtime/kernel/arm/fp32/deconvolution.h | 63 + .../arm/fp32/deconvolution_depthwise.cc | 162 + .../kernel/arm/fp32/deconvolution_depthwise.h | 56 + .../runtime/kernel/arm/fp32/depth_to_space.cc | 88 + .../runtime/kernel/arm/fp32/depth_to_space.h | 39 + .../src/runtime/kernel/arm/fp32/expandDims.cc | 97 + .../src/runtime/kernel/arm/fp32/expandDims.h | 54 + .../lite/src/runtime/kernel/arm/fp32/fill.cc | 107 + .../lite/src/runtime/kernel/arm/fp32/fill.h | 52 + .../src/runtime/kernel/arm/fp32/flatten.cc | 71 + .../src/runtime/kernel/arm/fp32/flatten.h | 47 + .../runtime/kernel/arm/fp32/fullconnection.cc | 118 + .../runtime/kernel/arm/fp32/fullconnection.h | 51 + .../kernel/arm/fp32/fused_batchnorm.cc | 68 + .../runtime/kernel/arm/fp32/fused_batchnorm.h | 45 + .../src/runtime/kernel/arm/fp32/gather.cc | 126 + .../lite/src/runtime/kernel/arm/fp32/gather.h | 45 + .../src/runtime/kernel/arm/fp32/gatherNd.cc | 148 + .../src/runtime/kernel/arm/fp32/gatherNd.h | 56 + .../kernel/arm/fp32/local_response_norm.cc | 110 + .../kernel/arm/fp32/local_response_norm.h | 47 + .../src/runtime/kernel/arm/fp32/matmul.cc | 53 + .../lite/src/runtime/kernel/arm/fp32/matmul.h | 45 + .../src/runtime/kernel/arm/fp32/nchw2nhwc.cc | 58 + .../src/runtime/kernel/arm/fp32/nchw2nhwc.h | 42 + .../src/runtime/kernel/arm/fp32/nhwc2nchw.cc | 58 + .../src/runtime/kernel/arm/fp32/nhwc2nchw.h | 42 + .../src/runtime/kernel/arm/fp32/one_hot.cc | 187 + .../src/runtime/kernel/arm/fp32/one_hot.h | 49 + .../lite/src/runtime/kernel/arm/fp32/pad.cc | 187 + .../lite/src/runtime/kernel/arm/fp32/pad.h | 59 + .../src/runtime/kernel/arm/fp32/pooling.cc | 78 + .../src/runtime/kernel/arm/fp32/pooling.h | 50 + .../lite/src/runtime/kernel/arm/fp32/power.cc | 80 + .../lite/src/runtime/kernel/arm/fp32/power.h | 50 + .../lite/src/runtime/kernel/arm/fp32/prelu.cc | 91 + .../lite/src/runtime/kernel/arm/fp32/prelu.h | 54 + .../lite/src/runtime/kernel/arm/fp32/range.cc | 70 + .../lite/src/runtime/kernel/arm/fp32/range.h | 39 + .../lite/src/runtime/kernel/arm/fp32/rank.cc | 68 + .../lite/src/runtime/kernel/arm/fp32/rank.h | 39 + .../src/runtime/kernel/arm/fp32/reduce.cc | 247 + .../lite/src/runtime/kernel/arm/fp32/reduce.h | 83 + .../src/runtime/kernel/arm/fp32/reshape.cc | 46 + .../src/runtime/kernel/arm/fp32/reshape.h | 45 + .../src/runtime/kernel/arm/fp32/resize.cc | 242 + .../lite/src/runtime/kernel/arm/fp32/resize.h | 65 + .../src/runtime/kernel/arm/fp32/reverse.cc | 161 + .../src/runtime/kernel/arm/fp32/reverse.h | 63 + .../kernel/arm/fp32/reverse_sequence.cc | 116 + .../kernel/arm/fp32/reverse_sequence.h | 43 + .../lite/src/runtime/kernel/arm/fp32/scale.cc | 168 + .../lite/src/runtime/kernel/arm/fp32/scale.h | 52 + .../src/runtime/kernel/arm/fp32/scatter_nd.cc | 186 + .../src/runtime/kernel/arm/fp32/scatter_nd.h | 52 + .../lite/src/runtime/kernel/arm/fp32/shape.cc | 84 + .../lite/src/runtime/kernel/arm/fp32/shape.h | 42 + .../lite/src/runtime/kernel/arm/fp32/slice.cc | 83 + .../lite/src/runtime/kernel/arm/fp32/slice.h | 39 + .../src/runtime/kernel/arm/fp32/softmax.cc | 79 + .../src/runtime/kernel/arm/fp32/softmax.h | 42 + .../kernel/arm/fp32/sparse_to_dense.cc | 102 + .../runtime/kernel/arm/fp32/sparse_to_dense.h | 59 + .../lite/src/runtime/kernel/arm/fp32/split.cc | 130 + .../lite/src/runtime/kernel/arm/fp32/split.h | 49 + .../src/runtime/kernel/arm/fp32/squeeze.cc | 79 + .../src/runtime/kernel/arm/fp32/squeeze.h | 43 + .../lite/src/runtime/kernel/arm/fp32/stack.cc | 113 + .../lite/src/runtime/kernel/arm/fp32/stack.h | 54 + .../runtime/kernel/arm/fp32/strided_slice.cc | 86 + .../runtime/kernel/arm/fp32/strided_slice.h | 44 + .../lite/src/runtime/kernel/arm/fp32/tile.cc | 82 + .../lite/src/runtime/kernel/arm/fp32/tile.h | 41 + .../lite/src/runtime/kernel/arm/fp32/topk.cc | 73 + .../lite/src/runtime/kernel/arm/fp32/topk.h | 43 + .../src/runtime/kernel/arm/fp32/transpose.cc | 100 + .../src/runtime/kernel/arm/fp32/transpose.h | 44 + .../src/runtime/kernel/arm/fp32/unique.cc | 67 + .../lite/src/runtime/kernel/arm/fp32/unique.h | 40 + .../src/runtime/kernel/arm/fp32/unsqueeze.cc | 96 + .../src/runtime/kernel/arm/fp32/unsqueeze.h | 51 + .../src/runtime/kernel/arm/fp32/unstack.cc | 91 + .../src/runtime/kernel/arm/fp32/unstack.h | 43 + .../lite/src/runtime/kernel/arm/fp32/where.cc | 110 + .../lite/src/runtime/kernel/arm/fp32/where.h | 56 + .../src/runtime/kernel/arm/fp32/zeroslike.cc | 68 + .../src/runtime/kernel/arm/fp32/zeroslike.h | 39 + .../src/runtime/kernel/arm/int8/add_int8.cc | 146 + .../src/runtime/kernel/arm/int8/add_int8.h | 51 + .../runtime/kernel/arm/int8/bias_add_int8.cc | 84 + .../runtime/kernel/arm/int8/bias_add_int8.h | 42 + .../runtime/kernel/arm/int8/concat_int8.cc | 144 + .../src/runtime/kernel/arm/int8/concat_int8.h | 45 + .../kernel/arm/int8/convolution_3x3_int8.cc | 241 + .../kernel/arm/int8/convolution_3x3_int8.h | 54 + .../arm/int8/convolution_depthwise_int8.cc | 146 + .../arm/int8/convolution_depthwise_int8.h | 56 + .../kernel/arm/int8/convolution_int8.cc | 388 ++ .../kernel/arm/int8/convolution_int8.h | 75 + .../arm/int8/deconvolution_depthwise_int8.cc | 174 + .../arm/int8/deconvolution_depthwise_int8.h | 58 + .../kernel/arm/int8/deconvolution_int8.cc | 220 + .../kernel/arm/int8/deconvolution_int8.h | 63 + .../kernel/arm/int8/fullconnection_int8.cc | 89 + .../kernel/arm/int8/fullconnection_int8.h | 51 + .../src/runtime/kernel/arm/int8/mul_int8.cc | 132 + .../src/runtime/kernel/arm/int8/mul_int8.h | 51 + .../runtime/kernel/arm/int8/pooling_int8.cc | 87 + .../runtime/kernel/arm/int8/pooling_int8.h | 45 + .../runtime/kernel/arm/int8/reshape_int8.cc | 74 + .../runtime/kernel/arm/int8/reshape_int8.h | 47 + .../runtime/kernel/arm/opclib/CMakeLists.txt | 37 + .../src/runtime/kernel/arm/opclib/add_int8.cc | 123 + .../src/runtime/kernel/arm/opclib/add_int8.h | 54 + .../kernel/arm/opclib/arithmetic_common.cc | 99 + .../kernel/arm/opclib/arithmetic_common.h | 52 + .../kernel/arm/opclib/arithmetic_parameter.h | 26 + .../assembly/arm32/IndirectGemmFp32_8x4.S | 294 ++ .../arm32/IndirectGemmInt16to32_8x4.S | 240 + .../assembly/arm32/IndirectGemmInt8_2x4.S | 229 + .../assembly/arm64/IndirectGemmFp16_16x8.S | 720 +++ .../assembly/arm64/IndirectGemmFp32_8x8.S | 730 +++ .../arm64/IndirectGemmInt16to32_8x4.S | 221 + .../assembly/arm64/IndirectGemmInt8_4x4.S | 326 ++ .../arm/opclib/assembly/arm64/bias_add.S | 82 + .../arm/opclib/assembly/arm64/bias_add_relu.S | 94 + .../opclib/assembly/arm64/bias_add_relu6.S | 113 + .../kernel/arm/opclib/assembly/arm64/matmul.s | 315 ++ .../arm/opclib/assembly/arm64/matrix_add.S | 103 + .../arm/opclib/assembly/arm64/matrix_sub.S | 105 + .../kernel/arm/opclib/assembly/arm64/relu.S | 73 + .../kernel/arm/opclib/assembly/arm64/relu6.S | 89 + .../assembly/opt/IndirectGemmInt8_24x4_dp.S | 636 +++ .../runtime/kernel/arm/opclib/common_func.cc | 169 + .../runtime/kernel/arm/opclib/common_func.h | 57 + .../kernel/arm/opclib/concat_parameter.h | 29 + .../kernel/arm/opclib/conv_parameter.h | 59 + .../src/runtime/kernel/arm/opclib/errorcode.h | 47 + .../src/runtime/kernel/arm/opclib/flatten.cc | 22 + .../src/runtime/kernel/arm/opclib/flatten.h | 27 + .../kernel/arm/opclib/fp16/conv_fp16.cc | 219 + .../kernel/arm/opclib/fp16/conv_fp16.h | 42 + .../kernel/arm/opclib/fp32/activation.h | 78 + .../kernel/arm/opclib/fp32/arg_min_max.cc | 82 + .../kernel/arm/opclib/fp32/arg_min_max.h | 37 + .../kernel/arm/opclib/fp32/arithmetic.cc | 526 ++ .../kernel/arm/opclib/fp32/arithmetic.h | 98 + .../kernel/arm/opclib/fp32/arithmetic_self.cc | 123 + .../kernel/arm/opclib/fp32/arithmetic_self.h | 56 + .../kernel/arm/opclib/fp32/batch_to_space.cc | 94 + .../kernel/arm/opclib/fp32/batch_to_space.h | 33 + .../kernel/arm/opclib/fp32/broadcast_to.cc | 108 + .../kernel/arm/opclib/fp32/broadcast_to.h | 41 + .../runtime/kernel/arm/opclib/fp32/cast.cc | 55 + .../src/runtime/kernel/arm/opclib/fp32/cast.h | 40 + .../kernel/arm/opclib/fp32/common_func.cc | 105 + .../kernel/arm/opclib/fp32/common_func.h | 52 + .../runtime/kernel/arm/opclib/fp32/concat.cc | 44 + .../runtime/kernel/arm/opclib/fp32/concat.h | 25 + .../runtime/kernel/arm/opclib/fp32/conv.cc | 194 + .../src/runtime/kernel/arm/opclib/fp32/conv.h | 51 + .../kernel/arm/opclib/fp32/conv_depthwise.cc | 351 ++ .../kernel/arm/opclib/fp32/conv_depthwise.h | 49 + .../runtime/kernel/arm/opclib/fp32/crop.cc | 59 + .../src/runtime/kernel/arm/opclib/fp32/crop.h | 30 + .../runtime/kernel/arm/opclib/fp32/deconv.cc | 78 + .../runtime/kernel/arm/opclib/fp32/deconv.h | 33 + .../kernel/arm/opclib/fp32/depth_to_space.cc | 43 + .../kernel/arm/opclib/fp32/depth_to_space.h | 29 + .../kernel/arm/opclib/fp32/expandDims.cc | 25 + .../kernel/arm/opclib/fp32/expandDims.h | 30 + .../runtime/kernel/arm/opclib/fp32/fill.cc | 25 + .../src/runtime/kernel/arm/opclib/fp32/fill.h | 36 + .../runtime/kernel/arm/opclib/fp32/gather.cc | 43 + .../runtime/kernel/arm/opclib/fp32/gather.h | 32 + .../kernel/arm/opclib/fp32/gatherNd.cc | 28 + .../runtime/kernel/arm/opclib/fp32/gatherNd.h | 30 + .../arm/opclib/fp32/local_response_norm.cc | 42 + .../arm/opclib/fp32/local_response_norm.h | 34 + .../runtime/kernel/arm/opclib/fp32/matmul.cc | 78 + .../runtime/kernel/arm/opclib/fp32/matmul.h | 39 + .../runtime/kernel/arm/opclib/fp32/one_hot.cc | 49 + .../runtime/kernel/arm/opclib/fp32/one_hot.h | 37 + .../runtime/kernel/arm/opclib/fp32/pooling.cc | 210 + .../runtime/kernel/arm/opclib/fp32/pooling.h | 56 + .../runtime/kernel/arm/opclib/fp32/range.cc | 25 + .../runtime/kernel/arm/opclib/fp32/range.h | 33 + .../runtime/kernel/arm/opclib/fp32/rank.cc | 22 + .../src/runtime/kernel/arm/opclib/fp32/rank.h | 24 + .../runtime/kernel/arm/opclib/fp32/reduce.cc | 146 + .../runtime/kernel/arm/opclib/fp32/reduce.h | 42 + .../runtime/kernel/arm/opclib/fp32/reverse.cc | 28 + .../runtime/kernel/arm/opclib/fp32/reverse.h | 36 + .../runtime/kernel/arm/opclib/fp32/slice.cc | 80 + .../runtime/kernel/arm/opclib/fp32/slice.h | 34 + .../runtime/kernel/arm/opclib/fp32/softmax.cc | 60 + .../runtime/kernel/arm/opclib/fp32/softmax.h | 34 + .../runtime/kernel/arm/opclib/fp32/stack.cc | 44 + .../runtime/kernel/arm/opclib/fp32/stack.h | 28 + .../kernel/arm/opclib/fp32/strassen_matmul.cc | 208 + .../kernel/arm/opclib/fp32/strassen_matmul.h | 40 + .../kernel/arm/opclib/fp32/strided_slice.cc | 82 + .../kernel/arm/opclib/fp32/strided_slice.h | 32 + .../kernel/arm/opclib/fp32/unsqueeze.cc | 25 + .../kernel/arm/opclib/fp32/unsqueeze.h | 34 + .../kernel/arm/opclib/fused_batchnorm.cc | 34 + .../kernel/arm/opclib/fused_batchnorm.h | 33 + .../kernel/arm/opclib/int8/concat_int8.cc | 64 + .../kernel/arm/opclib/int8/concat_int8.h | 25 + .../arm/opclib/int8/conv_depthwise_int8.cc | 322 ++ .../arm/opclib/int8/conv_depthwise_int8.h | 30 + .../kernel/arm/opclib/int8/conv_int8.cc | 338 ++ .../kernel/arm/opclib/int8/conv_int8.h | 57 + .../runtime/kernel/arm/opclib/int8/deconv.cc | 68 + .../runtime/kernel/arm/opclib/int8/deconv.h | 34 + .../runtime/kernel/arm/opclib/int8/matmul.cc | 101 + .../runtime/kernel/arm/opclib/int8/matmul.h | 42 + .../kernel/arm/opclib/int8/mul_int8.cc | 88 + .../runtime/kernel/arm/opclib/int8/mul_int8.h | 25 + .../kernel/arm/opclib/int8/pooling_int8.cc | 372 ++ .../kernel/arm/opclib/int8/pooling_int8.h | 34 + .../kernel/arm/opclib/int8/reshape_int8.cc | 41 + .../kernel/arm/opclib/int8/reshape_int8.h | 25 + .../src/runtime/kernel/arm/opclib/matmul.h | 37 + .../runtime/kernel/arm/opclib/matrix_table.h | 512 ++ .../runtime/kernel/arm/opclib/mul_parameter.h | 28 + .../runtime/kernel/arm/opclib/offset_utils.h | 30 + .../src/runtime/kernel/arm/opclib/op_base.h | 58 + .../runtime/kernel/arm/opclib/opclib_utils.cc | 28 + .../runtime/kernel/arm/opclib/opclib_utils.h | 27 + .../kernel/arm/opclib/opt_op_handler.c | 31 + .../kernel/arm/opclib/optimized_kernel.h | 64 + .../src/runtime/kernel/arm/opclib/pack.cc | 1093 +++++ .../lite/src/runtime/kernel/arm/opclib/pack.h | 166 + .../lite/src/runtime/kernel/arm/opclib/pad.cc | 37 + .../lite/src/runtime/kernel/arm/opclib/pad.h | 35 + .../src/runtime/kernel/arm/opclib/power.h | 36 + .../src/runtime/kernel/arm/opclib/prelu.cc | 27 + .../src/runtime/kernel/arm/opclib/prelu.h | 31 + .../arm/opclib/quantization/fixed_point.h | 687 +++ .../arm/opclib/quantization/quantize.cc | 77 + .../kernel/arm/opclib/quantization/quantize.h | 110 + .../src/runtime/kernel/arm/opclib/reshape.cc | 22 + .../src/runtime/kernel/arm/opclib/reshape.h | 24 + .../kernel/arm/opclib/reshape_parameter.h | 28 + .../src/runtime/kernel/arm/opclib/resize.cc | 136 + .../src/runtime/kernel/arm/opclib/resize.h | 45 + .../kernel/arm/opclib/reverse_sequence.cc | 42 + .../kernel/arm/opclib/reverse_sequence.h | 42 + .../src/runtime/kernel/arm/opclib/scale.cc | 53 + .../src/runtime/kernel/arm/opclib/scale.h | 35 + .../runtime/kernel/arm/opclib/scatter_nd.cc | 31 + .../runtime/kernel/arm/opclib/scatter_nd.h | 28 + .../src/runtime/kernel/arm/opclib/shape.h | 27 + .../kernel/arm/opclib/sparse_to_dense.cc | 35 + .../kernel/arm/opclib/sparse_to_dense.h | 31 + .../src/runtime/kernel/arm/opclib/split.cc | 59 + .../src/runtime/kernel/arm/opclib/split.h | 36 + .../src/runtime/kernel/arm/opclib/squeeze.cc | 27 + .../src/runtime/kernel/arm/opclib/squeeze.h | 30 + .../kernel/arm/opclib/strassen_matmul.h | 34 + .../src/runtime/kernel/arm/opclib/tile.cc | 47 + .../lite/src/runtime/kernel/arm/opclib/tile.h | 35 + .../src/runtime/kernel/arm/opclib/topk.cc | 55 + .../lite/src/runtime/kernel/arm/opclib/topk.h | 39 + .../runtime/kernel/arm/opclib/transpose.cc | 125 + .../src/runtime/kernel/arm/opclib/transpose.h | 39 + .../src/runtime/kernel/arm/opclib/unique.cc | 40 + .../src/runtime/kernel/arm/opclib/unique.h | 29 + .../src/runtime/kernel/arm/opclib/unstack.cc | 31 + .../src/runtime/kernel/arm/opclib/unstack.h | 34 + .../src/runtime/kernel/arm/opclib/where.cc | 27 + .../src/runtime/kernel/arm/opclib/where.h | 33 + .../kernel/arm/opclib/winograd_transform.cc | 1883 ++++++++ .../kernel/arm/opclib/winograd_transform.h | 86 + .../kernel/arm/opclib/winograd_utils.cc | 3804 +++++++++++++++ .../kernel/arm/opclib/winograd_utils.h | 58 + .../runtime/kernel/arm/opclib/zeroslike.cc | 21 + .../src/runtime/kernel/arm/opclib/zeroslike.h | 24 + .../src/runtime/kernel/opencl/CMakeLists.txt | 12 + .../opencl/cl/fp16/conv2d_transpose2x2.cl | 52 + .../kernel/opencl/cl/fp16/depthwise_conv2d.cl | 96 + .../runtime/kernel/opencl/cl/fp16/matmul.cl | 32 + .../kernel/opencl/cl/fp32/arithmetic.cl | 49 + .../kernel/opencl/cl/fp32/avg_pool2d.cl | 66 + .../runtime/kernel/opencl/cl/fp32/concat.cl | 60 + .../opencl/cl/fp32/conv2d_transpose2x2.cl | 51 + .../kernel/opencl/cl/fp32/convolution.cl | 87 + .../kernel/opencl/cl/fp32/depthwise_conv2d.cl | 95 + .../runtime/kernel/opencl/cl/fp32/matmul.cl | 31 + .../kernel/opencl/cl/fp32/max_pool2d.cl | 68 + .../runtime/kernel/opencl/cl/fp32/softmax.cl | 35 + .../src/runtime/kernel/opencl/image_format.h | 64 + .../kernel/opencl/kernel/arithmetic.cc | 132 + .../runtime/kernel/opencl/kernel/arithmetic.h | 45 + .../runtime/kernel/opencl/kernel/concat.cc | 136 + .../src/runtime/kernel/opencl/kernel/concat.h | 57 + .../kernel/opencl/kernel/conv2d_transpose.cc | 180 + .../kernel/opencl/kernel/conv2d_transpose.h | 56 + .../kernel/opencl/kernel/convolution.cc | 202 + .../kernel/opencl/kernel/convolution.h | 48 + .../kernel/opencl/kernel/depthwise_conv2d.cc | 150 + .../kernel/opencl/kernel/depthwise_conv2d.h | 50 + .../runtime/kernel/opencl/kernel/matmul.cc | 151 + .../src/runtime/kernel/opencl/kernel/matmul.h | 60 + .../runtime/kernel/opencl/kernel/pooling2d.cc | 141 + .../runtime/kernel/opencl/kernel/pooling2d.h | 52 + .../runtime/kernel/opencl/kernel/softmax.cc | 101 + .../runtime/kernel/opencl/kernel/softmax.h | 51 + .../kernel/opencl/subgraph_opencl_kernel.cc | 85 + .../kernel/opencl/subgraph_opencl_kernel.h | 55 + .../lite/src/runtime/kernel/opencl/utils.cc | 174 + .../lite/src/runtime/kernel/opencl/utils.h | 88 + .../lite/src/runtime/opencl/CMakeLists.txt | 11 + .../src/runtime/opencl/opencl_allocator.cc | 212 + .../src/runtime/opencl/opencl_allocator.h | 76 + .../src/runtime/opencl/opencl_executor.cc | 140 + .../lite/src/runtime/opencl/opencl_executor.h | 53 + .../lite/src/runtime/opencl/opencl_runtime.cc | 609 +++ .../lite/src/runtime/opencl/opencl_runtime.h | 155 + .../lite/src/runtime/opencl/opencl_wrapper.cc | 683 +++ .../lite/src/runtime/opencl/opencl_wrapper.h | 240 + mindspore/lite/src/runtime/runtime_api.cc | 105 + mindspore/lite/src/runtime/runtime_api.h | 57 + mindspore/lite/src/runtime/thread_pool.cc | 456 ++ mindspore/lite/src/runtime/thread_pool.h | 126 + mindspore/lite/src/runtime/workspace_pool.cc | 143 + mindspore/lite/src/runtime/workspace_pool.h | 45 + mindspore/lite/src/scheduler.cc | 173 + mindspore/lite/src/scheduler.h | 55 + mindspore/lite/src/train/base_ref_utils.cc | 60 + mindspore/lite/src/train/base_ref_utils.h | 31 + mindspore/lite/src/train/import.hpp | 46 + .../lite/src/train/lite_kernel_runtime.cc | 83 + .../lite/src/train/lite_kernel_runtime.h | 57 + mindspore/lite/src/train/model_impl.cc | 119 + mindspore/lite/src/train/model_impl.h | 52 + mindspore/lite/src/train/train_session.cc | 232 + mindspore/lite/src/train/train_session.h | 65 + mindspore/lite/test/CMakeLists.txt | 298 ++ mindspore/lite/test/benchmark_test.cc | 47 + mindspore/lite/test/converter_test.cc | 58 + mindspore/lite/tools/benchmark/CMakeLists.txt | 22 + mindspore/lite/tools/benchmark/benchmark.cc | 531 +++ mindspore/lite/tools/benchmark/benchmark.h | 146 + mindspore/lite/tools/benchmark/main.cc | 20 + mindspore/lite/tools/common/CMakeLists.txt | 9 + .../lite/tools/common/converter_op_utils.h | 34 + mindspore/lite/tools/common/flag_parser.cc | 180 + mindspore/lite/tools/common/flag_parser.h | 301 ++ mindspore/lite/tools/common/graph_util.cc | 671 +++ mindspore/lite/tools/common/graph_util.h | 107 + mindspore/lite/tools/common/node_util.cc | 178 + mindspore/lite/tools/common/node_util.h | 373 ++ mindspore/lite/tools/common/option.h | 120 + mindspore/lite/tools/common/storage.cc | 65 + mindspore/lite/tools/common/storage.h | 38 + mindspore/lite/tools/common/tensor_util.cc | 191 + mindspore/lite/tools/common/tensor_util.h | 123 + mindspore/lite/tools/converter/CMakeLists.txt | 109 + .../lite/tools/converter/anf_transform.cc | 45 + .../lite/tools/converter/anf_transform.h | 43 + mindspore/lite/tools/converter/converter.cc | 195 + mindspore/lite/tools/converter/converter.h | 52 + .../lite/tools/converter/converter_flags.cc | 176 + .../lite/tools/converter/converter_flags.h | 88 + .../tools/converter/graphdef_transform.cc | 183 + .../lite/tools/converter/graphdef_transform.h | 51 + mindspore/lite/tools/converter/main.cc | 20 + mindspore/lite/tools/converter/model_parser.h | 62 + mindspore/lite/tools/converter/optimizer.cc | 81 + mindspore/lite/tools/converter/optimizer.h | 86 + .../tools/converter/optimizer/CMakeLists.txt | 6 + .../optimizer/const_fold/CMakeLists.txt | 50 + .../const_fold/add_const_fold_pass.cc | 98 + .../const_fold/add_const_fold_pass.h | 41 + .../const_fold/cast_const_fold_pass.cc | 68 + .../const_fold/cast_const_fold_pass.h | 40 + .../const_fold/concat_v2_const_fold_pass.cc | 66 + .../const_fold/concat_v2_const_fold_pass.h | 110 + .../optimizer/const_fold/const_fold_pass.cc | 207 + .../optimizer/const_fold/const_fold_pass.h | 64 + .../const_fold/expand_dims_const_fold_pass.cc | 66 + .../const_fold/expand_dims_const_fold_pass.h | 40 + .../const_fold/mul_const_fold_pass.cc | 101 + .../const_fold/mul_const_fold_pass.h | 41 + .../const_fold/range_const_fold_pass.cc | 68 + .../const_fold/range_const_fold_pass.h | 41 + .../const_fold/reshape_const_fold_pass.cc | 66 + .../const_fold/reshape_const_fold_pass.h | 43 + .../const_fold/rsqrt_const_fold_pass.cc | 67 + .../const_fold/rsqrt_const_fold_pass.h | 41 + .../const_fold/shape_const_fold_pass.cc | 65 + .../const_fold/shape_const_fold_pass.h | 40 + .../const_fold/slice_const_fold_pass.cc | 66 + .../const_fold/slice_const_fold_pass.h | 41 + .../const_fold/stack_const_fold_pass.cc | 65 + .../const_fold/stack_const_fold_pass.h | 42 + .../strided_slice_const_fold_pass.cc | 65 + .../strided_slice_const_fold_pass.h | 41 + .../const_fold/sub_const_fold_pass.cc | 101 + .../const_fold/sub_const_fold_pass.h | 41 + .../const_fold/tile_const_fold_pass.cc | 66 + .../const_fold/tile_const_fold_pass.h | 42 + .../const_fold/transpose_const_fold_pass.cc | 67 + .../const_fold/transpose_const_fold_pass.h | 41 + .../converter/optimizer/fusion/CMakeLists.txt | 17 + .../fusion/batchnorm_fold_fusion_pass.cc | 500 ++ .../fusion/batchnorm_fold_fusion_pass.h | 87 + .../fusion/conv_activation_fusion_pass.cc | 101 + .../fusion/conv_activation_fusion_pass.h | 50 + .../fusion/conv_biasadd_fusion_pass.cc | 288 ++ .../fusion/conv_biasadd_fusion_pass.h | 51 + .../optimizer/fusion/conv_bn_fusion_pass.cc | 224 + .../optimizer/fusion/conv_bn_fusion_pass.h | 54 + .../fusion/conv_relu6_fusion_pass.cc | 41 + .../optimizer/fusion/conv_relu6_fusion_pass.h | 46 + .../optimizer/fusion/conv_relu_fusion_pass.cc | 40 + .../optimizer/fusion/conv_relu_fusion_pass.h | 45 + .../fusion/conv_scale_bias_fusion_pass.cc | 361 ++ .../fusion/conv_scale_bias_fusion_pass.h | 67 + .../fusion/conv_scale_fusion_pass.cc | 126 + .../optimizer/fusion/conv_scale_fusion_pass.h | 46 + .../fusion/format_trans_fusion_pass.cc | 185 + .../fusion/format_trans_fusion_pass.h | 52 + .../converter/optimizer/fusion/fusion_pass.cc | 349 ++ .../converter/optimizer/fusion/fusion_pass.h | 87 + .../optimizer/fusion/fusion_pattern.cc | 182 + .../optimizer/fusion/fusion_pattern.h | 141 + .../fusion/matmul_biasadd_fusion_pass.cc | 225 + .../fusion/matmul_biasadd_fusion_pass.h | 84 + .../fusion/quant_cast_fusion_pass.cc | 139 + .../optimizer/fusion/quant_cast_fusion_pass.h | 51 + .../converter/optimizer/graph/CMakeLists.txt | 7 + .../optimizer/graph/format_trans_pass.cc | 200 + .../optimizer/graph/format_trans_pass.h | 57 + .../graph/isolated_node_remove_pass.cc | 46 + .../graph/isolated_node_remove_pass.h | 37 + .../model_input_format_preprocess_pass.cc | 46 + .../model_input_format_preprocess_pass.h | 38 + .../optimizer/graph/topological_sort_pass.cc | 82 + .../optimizer/graph/topological_sort_pass.h} | 31 +- .../graph/unused_node_remove_pass.cc | 50 + .../optimizer/graph/unused_node_remove_pass.h | 37 + .../converter/optimizer/node/CMakeLists.txt | 3 + .../optimizer/node/weight_format_pass.cc | 394 ++ .../optimizer/node/weight_format_pass.h | 58 + .../converter/parser/caffe/CMakeLists.txt | 52 + .../tools/converter/parser/caffe/caffe.proto | 1675 +++++++ .../parser/caffe/caffe_argmax_parser.cc | 58 + .../parser/caffe/caffe_argmax_parser.h | 37 + .../parser/caffe/caffe_batchnorm_parser.cc | 111 + .../parser/caffe/caffe_batchnorm_parser.h | 37 + .../parser/caffe/caffe_concat_parser.cc | 65 + .../parser/caffe/caffe_concat_parser.h | 37 + .../parser/caffe/caffe_conv_base_parser.cc | 218 + .../parser/caffe/caffe_conv_base_parser.h | 53 + .../converter/parser/caffe/caffe_converter.cc | 27 + .../converter/parser/caffe/caffe_converter.h | 36 + .../parser/caffe/caffe_convolution_parser.cc | 119 + .../parser/caffe/caffe_convolution_parser.h | 41 + .../parser/caffe/caffe_crop_parser.cc | 62 + .../parser/caffe/caffe_crop_parser.h | 37 + .../caffe/caffe_deconvolution_parser.cc | 118 + .../parser/caffe/caffe_deconvolution_parser.h | 41 + .../parser/caffe/caffe_eltwise_parser.cc | 72 + .../parser/caffe/caffe_eltwise_parser.h | 37 + .../parser/caffe/caffe_innerproduct_parser.cc | 75 + .../parser/caffe/caffe_innerproduct_parser.h | 37 + .../converter/parser/caffe/caffe_inspector.cc | 79 + .../converter/parser/caffe/caffe_inspector.h | 56 + .../parser/caffe/caffe_interp_parser.cc | 58 + .../parser/caffe/caffe_interp_parser.h | 37 + .../parser/caffe/caffe_model_parser.cc | 307 ++ .../parser/caffe/caffe_model_parser.h | 64 + .../parser/caffe/caffe_node_parser.cc | 102 + .../parser/caffe/caffe_node_parser.h | 51 + .../caffe/caffe_node_parser_registry.cc | 39 + .../parser/caffe/caffe_node_parser_registry.h | 48 + .../parser/caffe/caffe_parse_utils.cc | 103 + .../parser/caffe/caffe_parse_utils.h | 40 + .../parser/caffe/caffe_pooling_parser.cc | 155 + .../parser/caffe/caffe_pooling_parser.h | 45 + .../parser/caffe/caffe_power_parser.cc | 50 + .../parser/caffe/caffe_power_parser.h | 37 + .../parser/caffe/caffe_prelu_parser.cc | 55 + .../parser/caffe/caffe_prelu_parser.h | 37 + .../parser/caffe/caffe_relu_parser.cc | 50 + .../parser/caffe/caffe_relu_parser.h | 37 + .../parser/caffe/caffe_reshape_parser.cc | 49 + .../parser/caffe/caffe_reshape_parser.h | 37 + .../parser/caffe/caffe_scale_parser.cc | 97 + .../parser/caffe/caffe_scale_parser.h | 39 + .../parser/caffe/caffe_sigmoid_parser.cc | 37 + .../parser/caffe/caffe_sigmoid_parser.h | 37 + .../parser/caffe/caffe_softmax_parser.cc | 47 + .../parser/caffe/caffe_softmax_parser.h | 37 + .../converter/parser/onnx/CMakeLists.txt | 5 + .../tools/converter/parser/onnx/onnx.proto | 569 +++ .../parser/onnx/onnx_argmax_parser.cc | 45 + .../parser/onnx/onnx_argmax_parser.h | 33 + .../onnx/onnx_arithmetic_operation_parser.cc | 270 ++ .../onnx/onnx_arithmetic_operation_parser.h | 171 + .../parser/onnx/onnx_batchnorm_parser.cc | 44 + .../parser/onnx/onnx_batchnorm_parser.h | 33 + .../parser/onnx/onnx_biasadd_parser.cc | 42 + .../parser/onnx/onnx_biasadd_parser.h | 34 + .../converter/parser/onnx/onnx_cast_parser.cc | 41 + .../converter/parser/onnx/onnx_cast_parser.h | 33 + .../converter/parser/onnx/onnx_clip_parser.cc | 43 + .../converter/parser/onnx/onnx_clip_parser.h | 33 + .../parser/onnx/onnx_concat_parser.cc | 43 + .../parser/onnx/onnx_concat_parser.h | 33 + .../parser/onnx/onnx_constant_parser.cc | 36 + .../parser/onnx/onnx_constant_parser.h | 33 + .../converter/parser/onnx/onnx_conv_parser.cc | 172 + .../converter/parser/onnx/onnx_conv_parser.h | 36 + .../converter/parser/onnx/onnx_converter.cc | 26 + .../converter/parser/onnx/onnx_converter.h | 37 + .../parser/onnx/onnx_deconv_parser.cc | 154 + .../parser/onnx/onnx_deconv_parser.h | 36 + .../parser/onnx/onnx_depth_to_space_parser.cc | 43 + .../parser/onnx/onnx_depth_to_space_parser.h | 33 + .../parser/onnx/onnx_dropout_parser.cc | 43 + .../parser/onnx/onnx_dropout_parser.h | 33 + .../converter/parser/onnx/onnx_elu_parser.cc | 41 + .../converter/parser/onnx/onnx_elu_parser.h | 33 + .../parser/onnx/onnx_expand_parser.cc | 36 + .../parser/onnx/onnx_expand_parser.h | 33 + .../parser/onnx/onnx_flatten_parser.cc | 49 + .../parser/onnx/onnx_flatten_parser.h | 34 + .../parser/onnx/onnx_gather_parser.cc | 43 + .../parser/onnx/onnx_gather_parser.h | 33 + .../converter/parser/onnx/onnx_lrn_parser.cc | 47 + .../converter/parser/onnx/onnx_lrn_parser.h | 33 + .../parser/onnx/onnx_matmul_parser.cc | 56 + .../parser/onnx/onnx_matmul_parser.h | 33 + .../parser/onnx/onnx_model_parser.cc | 512 ++ .../converter/parser/onnx/onnx_model_parser.h | 80 + .../converter/parser/onnx/onnx_node_parser.cc | 35 + .../converter/parser/onnx/onnx_node_parser.h | 43 + .../parser/onnx/onnx_node_parser_registry.cc | 45 + .../parser/onnx/onnx_node_parser_registry.h | 49 + .../converter/parser/onnx/onnx_pad_parser.cc | 55 + .../converter/parser/onnx/onnx_pad_parser.h | 33 + .../converter/parser/onnx/onnx_pool_parser.cc | 92 + .../converter/parser/onnx/onnx_pool_parser.h | 33 + .../parser/onnx/onnx_reduce_parser.cc | 65 + .../parser/onnx/onnx_reduce_parser.h | 33 + .../converter/parser/onnx/onnx_relu_parser.cc | 81 + .../converter/parser/onnx/onnx_relu_parser.h | 44 + .../parser/onnx/onnx_reshape_parser.cc | 62 + .../parser/onnx/onnx_reshape_parser.h | 33 + .../parser/onnx/onnx_shape_parser.cc | 36 + .../converter/parser/onnx/onnx_shape_parser.h | 33 + .../parser/onnx/onnx_sigmoid_parser.cc | 38 + .../parser/onnx/onnx_sigmoid_parser.h | 33 + .../parser/onnx/onnx_slice_parser.cc | 51 + .../converter/parser/onnx/onnx_slice_parser.h | 33 + .../parser/onnx/onnx_softmax_parser.cc | 43 + .../parser/onnx/onnx_softmax_parser.h | 33 + .../parser/onnx/onnx_space_to_depth_parser.cc | 43 + .../parser/onnx/onnx_space_to_depth_parser.h | 33 + .../parser/onnx/onnx_squeeze_parser.cc | 45 + .../parser/onnx/onnx_squeeze_parser.h | 33 + .../converter/parser/onnx/onnx_tile_parser.cc | 34 + .../converter/parser/onnx/onnx_tile_parser.h | 33 + .../parser/onnx/onnx_transpose_parser.cc | 53 + .../parser/onnx/onnx_transpose_parser.h | 33 + .../parser/onnx/onnx_unsample_parser.cc | 48 + .../parser/onnx/onnx_unsample_parser.h | 33 + .../parser/onnx/onnx_unsqueeze_parser.cc | 45 + .../parser/onnx/onnx_unsqueeze_parser.h | 33 + .../parser/onnx/onnx_unuseful_node_parser.cc | 52 + .../parser/onnx/onnx_unuseful_node_parser.h | 33 + .../converter/parser/tflite/CMakeLists.txt | 6 + .../tools/converter/parser/tflite/schema.fbs | 926 ++++ .../parser/tflite/tflite_add_parser.cc | 44 + .../parser/tflite/tflite_add_parser.h | 42 + .../parser/tflite/tflite_argmax_parser.cc | 43 + .../parser/tflite/tflite_argmax_parser.h | 41 + .../parser/tflite/tflite_concat_parser.cc | 53 + .../parser/tflite/tflite_concat_parser.h | 42 + .../parser/tflite/tflite_conv_parser.cc | 82 + .../parser/tflite/tflite_conv_parser.h | 42 + .../parser/tflite/tflite_converter.cc | 26 + .../parser/tflite/tflite_converter.h | 38 + .../tflite/tflite_depthwise_conv_parser.cc | 140 + .../tflite/tflite_depthwise_conv_parser.h | 47 + .../parser/tflite/tflite_fakequant_parser.cc | 59 + .../parser/tflite/tflite_fakequant_parser.h | 39 + .../tflite/tflite_fullyconnected_parser.cc | 61 + .../tflite/tflite_fullyconnected_parser.h | 41 + .../parser/tflite/tflite_logistic_parser.cc | 46 + .../parser/tflite/tflite_logistic_parser.h | 42 + .../tflite/tflite_max_pooling_parser.cc | 57 + .../parser/tflite/tflite_max_pooling_parser.h | 42 + .../parser/tflite/tflite_mean_parser.cc | 49 + .../parser/tflite/tflite_mean_parser.h | 41 + .../tflite/tflite_mean_pooling_parser.cc | 56 + .../tflite/tflite_mean_pooling_parser.h | 42 + .../parser/tflite/tflite_model_parser.cc | 251 + .../parser/tflite/tflite_model_parser.h | 86 + .../parser/tflite/tflite_mul_parser.cc | 58 + .../parser/tflite/tflite_mul_parser.h | 42 + .../parser/tflite/tflite_node_parser.cc | 102 + .../parser/tflite/tflite_node_parser.h | 129 + .../tflite/tflite_node_parser_registry.cc | 39 + .../tflite/tflite_node_parser_registry.h | 50 + .../parser/tflite/tflite_relu6_parser.cc | 41 + .../parser/tflite/tflite_relu6_parser.h | 41 + .../parser/tflite/tflite_reshape_parser.cc | 54 + .../parser/tflite/tflite_reshape_parser.h | 41 + .../tflite/tflite_resize_bilinear_parser.cc | 58 + .../tflite/tflite_resize_bilinear_parser.h | 42 + .../parser/tflite/tflite_rsqrt_parser.cc | 41 + .../parser/tflite/tflite_rsqrt_parser.h | 41 + .../parser/tflite/tflite_slice_parser.cc | 50 + .../parser/tflite/tflite_slice_parser.h | 41 + .../parser/tflite/tflite_softmax_parser.cc | 47 + .../parser/tflite/tflite_softmax_parser.h | 41 + .../tflite/tflite_squareddifference_parser.cc | 45 + .../tflite/tflite_squareddifference_parser.h | 41 + .../parser/tflite/tflite_stack_parser.cc | 50 + .../parser/tflite/tflite_stack_parser.h | 41 + .../parser/tflite/tflite_sub_parser.cc | 60 + .../parser/tflite/tflite_sub_parser.h | 42 + .../parser/tflite/tflite_tanh_parser.cc | 42 + .../parser/tflite/tflite_tanh_parser.h | 42 + .../parser/tflite/tflite_transpose_parser.cc | 59 + .../parser/tflite/tflite_transpose_parser.h | 41 + .../converter/parser/tflite/tflite_util.cc | 111 + .../converter/parser/tflite/tflite_util.h | 45 + .../tools/converter/quantizer/CMakeLists.txt | 19 + .../converter/quantizer/general_bitpacking.cc | 86 + .../converter/quantizer/general_bitpacking.h | 43 + .../converter/quantizer/post_training.cc | 926 ++++ .../tools/converter/quantizer/post_training.h | 159 + .../converter/quantizer/quantize_util.cc | 343 ++ .../tools/converter/quantizer/quantize_util.h | 107 + .../tools/converter/quantizer/quantizer.cc | 36 + .../tools/converter/quantizer/quantizer.h | 63 + .../converter/quantizer/weight_quantizer.cc | 151 + .../converter/quantizer/weight_quantizer.h | 53 + tests/ut/cpp/common/common_test.h | 44 + 885 files changed, 97615 insertions(+), 300 deletions(-) delete mode 100644 mindspore/core/ir/lite/tensor.cc delete mode 100644 mindspore/core/ir/lite/tensor.h create mode 100644 mindspore/lite/CMakeLists.txt create mode 100755 mindspore/lite/build.sh create mode 100644 mindspore/lite/cmake-build-cloud/Lite.cbp create mode 100644 mindspore/lite/cmake-build-cloud/googletest/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gmock.cbp create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfig.cmake create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock.pc create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock_main.pc create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest.pc create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest_main.pc create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/gtest.cbp create mode 100644 mindspore/lite/cmake-build-cloud/googletest/googletest-distribution.cbp create mode 100644 mindspore/lite/cmake-build-cloud/src/runtime/kernel/arm/opclib/optimize.cbp create mode 100644 mindspore/lite/cmake-build-minnie/Lite.cbp create mode 100644 mindspore/lite/cmake-build-minnie/googletest/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gmock.cbp create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/CTestTestfile.cmake create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfig.cmake create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock.pc create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock_main.pc create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest.pc create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest_main.pc create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/gtest.cbp create mode 100644 mindspore/lite/cmake-build-minnie/googletest/googletest-distribution.cbp create mode 100644 mindspore/lite/cmake-build-minnie/src/runtime/kernel/arm/opclib/optimize.cbp create mode 100644 mindspore/lite/include/context.h create mode 100644 mindspore/lite/include/errorcode.h create mode 100644 mindspore/lite/include/lite_session.h create mode 100644 mindspore/lite/include/model.h create mode 100644 mindspore/lite/include/ms_tensor.h create mode 100644 mindspore/lite/schema/model.fbs create mode 100644 mindspore/lite/schema/ops.fbs create mode 100644 mindspore/lite/src/CMakeLists.txt create mode 100644 mindspore/lite/src/common/anf_exporter/CMakeLists.txt create mode 100644 mindspore/lite/src/common/anf_exporter/anf_exporter.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_exporter.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc create mode 100644 mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h create mode 100644 mindspore/lite/src/common/anf_importer/anf_importer.cc create mode 100644 mindspore/lite/src/common/anf_importer/anf_importer.h create mode 100644 mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc create mode 100644 mindspore/lite/src/common/anf_importer/import_from_meta_graph.h create mode 100644 mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc create mode 100644 mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h create mode 100644 mindspore/lite/src/common/anf_importer/import_from_protobuf.cc create mode 100644 mindspore/lite/src/common/anf_importer/import_from_protobuf.h create mode 100755 mindspore/lite/src/common/common.h create mode 100644 mindspore/lite/src/common/file_utils.cc create mode 100644 mindspore/lite/src/common/file_utils.h create mode 100755 mindspore/lite/src/common/graph_util.cc create mode 100755 mindspore/lite/src/common/graph_util.h create mode 100644 mindspore/lite/src/common/graph_utils_extends.cc create mode 100755 mindspore/lite/src/common/op_utils.h create mode 100644 mindspore/lite/src/common/utils.cc create mode 100644 mindspore/lite/src/common/utils.h create mode 100644 mindspore/lite/src/context.cc create mode 100644 mindspore/lite/src/executor.cc create mode 100644 mindspore/lite/src/executor.h create mode 100644 mindspore/lite/src/gllo/common/node_pass.cc create mode 100644 mindspore/lite/src/gllo/common/node_pass.h create mode 100644 mindspore/lite/src/gllo/common/optimizer.cc create mode 100644 mindspore/lite/src/gllo/common/optimizer.h create mode 100644 mindspore/lite/src/gllo/common/pass.h create mode 100644 mindspore/lite/src/gllo/common/pass_manager.cc create mode 100644 mindspore/lite/src/gllo/common/pass_manager.h create mode 100644 mindspore/lite/src/gllo/common/pattern_engine.cc create mode 100644 mindspore/lite/src/gllo/common/pattern_engine.h create mode 100644 mindspore/lite/src/gllo/common/utils.cc create mode 100644 mindspore/lite/src/gllo/common/utils.h create mode 100644 mindspore/lite/src/gllo/common/visit.cc create mode 100644 mindspore/lite/src/gllo/common/visit.h create mode 100644 mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc create mode 100644 mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h create mode 100644 mindspore/lite/src/ir/meta_tensor_extends.cc create mode 100644 mindspore/lite/src/ir/primitive_t_value.cc create mode 100644 mindspore/lite/src/ir/primitive_t_value.h create mode 100644 mindspore/lite/src/ir/primitive_value.cc create mode 100644 mindspore/lite/src/ir/primitive_value.h create mode 100644 mindspore/lite/src/ir/tensor.cc create mode 100644 mindspore/lite/src/ir/tensor.h create mode 100644 mindspore/lite/src/kernel_factory.cc create mode 100644 mindspore/lite/src/kernel_factory.h create mode 100644 mindspore/lite/src/kernel_registry.cc create mode 100644 mindspore/lite/src/kernel_registry.h create mode 100644 mindspore/lite/src/lite_kernel.cc create mode 100644 mindspore/lite/src/lite_kernel.h create mode 100644 mindspore/lite/src/lite_session.cc create mode 100644 mindspore/lite/src/lite_session.h create mode 100644 mindspore/lite/src/model.cc create mode 100644 mindspore/lite/src/model_impl.cc create mode 100644 mindspore/lite/src/model_impl.h create mode 100644 mindspore/lite/src/ops/CMakeLists.txt create mode 100644 mindspore/lite/src/ops/addn.cc create mode 100644 mindspore/lite/src/ops/argmax.cc create mode 100644 mindspore/lite/src/ops/argmin.cc create mode 100644 mindspore/lite/src/ops/arithmetic.cc create mode 100644 mindspore/lite/src/ops/arithmetic_self.cc create mode 100644 mindspore/lite/src/ops/batch_to_space.cc create mode 100644 mindspore/lite/src/ops/broadcast_to.cc create mode 100644 mindspore/lite/src/ops/cast.cc create mode 100644 mindspore/lite/src/ops/concat.cc create mode 100644 mindspore/lite/src/ops/conv.cc create mode 100644 mindspore/lite/src/ops/convolution_depthwise.cc create mode 100644 mindspore/lite/src/ops/crop.cc create mode 100644 mindspore/lite/src/ops/deconvolution.cc create mode 100644 mindspore/lite/src/ops/deconvolution_depthwise.cc create mode 100644 mindspore/lite/src/ops/depth_to_space.cc create mode 100644 mindspore/lite/src/ops/expand_dims.cc create mode 100644 mindspore/lite/src/ops/fill.cc create mode 100644 mindspore/lite/src/ops/flatten.cc create mode 100644 mindspore/lite/src/ops/fullconnection.cc create mode 100644 mindspore/lite/src/ops/gather.cc create mode 100644 mindspore/lite/src/ops/gather_nd.cc create mode 100644 mindspore/lite/src/ops/matmul.cc create mode 100644 mindspore/lite/src/ops/nchw2nhwc.cc create mode 100644 mindspore/lite/src/ops/nhwc2nchw.cc create mode 100644 mindspore/lite/src/ops/one_hot.cc create mode 100644 mindspore/lite/src/ops/ops.cc create mode 100644 mindspore/lite/src/ops/ops.h create mode 100644 mindspore/lite/src/ops/pad.cc create mode 100644 mindspore/lite/src/ops/pooling.cc create mode 100644 mindspore/lite/src/ops/range.cc create mode 100644 mindspore/lite/src/ops/rank.cc create mode 100644 mindspore/lite/src/ops/reduce.cc create mode 100644 mindspore/lite/src/ops/reshape.cc create mode 100644 mindspore/lite/src/ops/resize.cc create mode 100644 mindspore/lite/src/ops/reverse_sequence.cc create mode 100644 mindspore/lite/src/ops/scatter_nd.cc create mode 100644 mindspore/lite/src/ops/shape.cc create mode 100644 mindspore/lite/src/ops/slice.cc create mode 100644 mindspore/lite/src/ops/softmax.cc create mode 100644 mindspore/lite/src/ops/split.cc create mode 100644 mindspore/lite/src/ops/squeeze.cc create mode 100644 mindspore/lite/src/ops/stack.cc create mode 100644 mindspore/lite/src/ops/strided_slice.cc create mode 100644 mindspore/lite/src/ops/tile.cc create mode 100644 mindspore/lite/src/ops/topk.cc create mode 100644 mindspore/lite/src/ops/transpose.cc create mode 100644 mindspore/lite/src/ops/unique.cc create mode 100644 mindspore/lite/src/ops/unsqueeze.cc create mode 100644 mindspore/lite/src/ops/unstack.cc create mode 100644 mindspore/lite/src/ops/where.cc create mode 100644 mindspore/lite/src/ops/zeroslike.cc create mode 100644 mindspore/lite/src/param_value_lite.h create mode 100644 mindspore/lite/src/populate_parameter.cc create mode 100644 mindspore/lite/src/populate_parameter.h create mode 100644 mindspore/lite/src/runtime/allocator.cc create mode 100644 mindspore/lite/src/runtime/allocator.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/concat_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/matrix.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/matrix.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/activation.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/addn.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/bias.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/cast.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/concat.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/crop.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fill.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/gather.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/pad.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/power.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/power.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/range.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/range.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/rank.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/resize.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/scale.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/shape.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/slice.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/split.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/split.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/stack.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/tile.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/topk.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unique.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/where.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/where.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmFp32_8x4.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt16to32_8x4.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt8_2x4.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp16_16x8.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp32_8x8.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt16to32_8x4.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt8_4x4.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu6.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_add.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_sub.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu6.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/assembly/opt/IndirectGemmInt8_24x4_dp.S create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/concat_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/errorcode.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/flatten.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/flatten.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/activation.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/matrix_table.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/mul_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/offset_utils.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/op_base.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/pack.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/pad.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/pad.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/power.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/prelu.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/prelu.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/reshape.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/reshape.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/reshape_parameter.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/resize.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/resize.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/scale.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/scale.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/shape.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/split.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/split.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/strassen_matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/tile.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/tile.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/topk.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/transpose.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/transpose.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/unique.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/unique.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/unstack.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/unstack.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/where.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/where.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.h create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/image_format.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/utils.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/utils.h create mode 100644 mindspore/lite/src/runtime/opencl/CMakeLists.txt create mode 100644 mindspore/lite/src/runtime/opencl/opencl_allocator.cc create mode 100644 mindspore/lite/src/runtime/opencl/opencl_allocator.h create mode 100644 mindspore/lite/src/runtime/opencl/opencl_executor.cc create mode 100644 mindspore/lite/src/runtime/opencl/opencl_executor.h create mode 100644 mindspore/lite/src/runtime/opencl/opencl_runtime.cc create mode 100644 mindspore/lite/src/runtime/opencl/opencl_runtime.h create mode 100644 mindspore/lite/src/runtime/opencl/opencl_wrapper.cc create mode 100644 mindspore/lite/src/runtime/opencl/opencl_wrapper.h create mode 100644 mindspore/lite/src/runtime/runtime_api.cc create mode 100644 mindspore/lite/src/runtime/runtime_api.h create mode 100644 mindspore/lite/src/runtime/thread_pool.cc create mode 100644 mindspore/lite/src/runtime/thread_pool.h create mode 100644 mindspore/lite/src/runtime/workspace_pool.cc create mode 100644 mindspore/lite/src/runtime/workspace_pool.h create mode 100644 mindspore/lite/src/scheduler.cc create mode 100644 mindspore/lite/src/scheduler.h create mode 100644 mindspore/lite/src/train/base_ref_utils.cc create mode 100644 mindspore/lite/src/train/base_ref_utils.h create mode 100644 mindspore/lite/src/train/import.hpp create mode 100644 mindspore/lite/src/train/lite_kernel_runtime.cc create mode 100644 mindspore/lite/src/train/lite_kernel_runtime.h create mode 100644 mindspore/lite/src/train/model_impl.cc create mode 100644 mindspore/lite/src/train/model_impl.h create mode 100644 mindspore/lite/src/train/train_session.cc create mode 100644 mindspore/lite/src/train/train_session.h create mode 100644 mindspore/lite/test/CMakeLists.txt create mode 100644 mindspore/lite/test/benchmark_test.cc create mode 100644 mindspore/lite/test/converter_test.cc create mode 100644 mindspore/lite/tools/benchmark/CMakeLists.txt create mode 100644 mindspore/lite/tools/benchmark/benchmark.cc create mode 100644 mindspore/lite/tools/benchmark/benchmark.h create mode 100644 mindspore/lite/tools/benchmark/main.cc create mode 100755 mindspore/lite/tools/common/CMakeLists.txt create mode 100644 mindspore/lite/tools/common/converter_op_utils.h create mode 100755 mindspore/lite/tools/common/flag_parser.cc create mode 100755 mindspore/lite/tools/common/flag_parser.h create mode 100755 mindspore/lite/tools/common/graph_util.cc create mode 100644 mindspore/lite/tools/common/graph_util.h create mode 100644 mindspore/lite/tools/common/node_util.cc create mode 100644 mindspore/lite/tools/common/node_util.h create mode 100644 mindspore/lite/tools/common/option.h create mode 100755 mindspore/lite/tools/common/storage.cc create mode 100644 mindspore/lite/tools/common/storage.h create mode 100644 mindspore/lite/tools/common/tensor_util.cc create mode 100644 mindspore/lite/tools/common/tensor_util.h create mode 100644 mindspore/lite/tools/converter/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/anf_transform.cc create mode 100644 mindspore/lite/tools/converter/anf_transform.h create mode 100644 mindspore/lite/tools/converter/converter.cc create mode 100644 mindspore/lite/tools/converter/converter.h create mode 100644 mindspore/lite/tools/converter/converter_flags.cc create mode 100644 mindspore/lite/tools/converter/converter_flags.h create mode 100644 mindspore/lite/tools/converter/graphdef_transform.cc create mode 100644 mindspore/lite/tools/converter/graphdef_transform.h create mode 100644 mindspore/lite/tools/converter/main.cc create mode 100644 mindspore/lite/tools/converter/model_parser.h create mode 100644 mindspore/lite/tools/converter/optimizer.cc create mode 100644 mindspore/lite/tools/converter/optimizer.h create mode 100755 mindspore/lite/tools/converter/optimizer/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.h create mode 100755 mindspore/lite/tools/converter/optimizer/fusion/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.h create mode 100755 mindspore/lite/tools/converter/optimizer/graph/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.h create mode 100644 mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.cc rename mindspore/{core/ir/lite/param_value_lite.h => lite/tools/converter/optimizer/graph/topological_sort_pass.h} (51%) create mode 100644 mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.h create mode 100755 mindspore/lite/tools/converter/optimizer/node/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc create mode 100644 mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt create mode 100755 mindspore/lite/tools/converter/parser/caffe/caffe.proto create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_converter.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h create mode 100755 mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx.proto create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h create mode 100755 mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc create mode 100755 mindspore/lite/tools/converter/parser/onnx/onnx_converter.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h create mode 100755 mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/parser/tflite/schema.fbs create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_converter.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_util.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_util.h create mode 100644 mindspore/lite/tools/converter/quantizer/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/quantizer/general_bitpacking.cc create mode 100644 mindspore/lite/tools/converter/quantizer/general_bitpacking.h create mode 100644 mindspore/lite/tools/converter/quantizer/post_training.cc create mode 100644 mindspore/lite/tools/converter/quantizer/post_training.h create mode 100644 mindspore/lite/tools/converter/quantizer/quantize_util.cc create mode 100644 mindspore/lite/tools/converter/quantizer/quantize_util.h create mode 100644 mindspore/lite/tools/converter/quantizer/quantizer.cc create mode 100644 mindspore/lite/tools/converter/quantizer/quantizer.h create mode 100644 mindspore/lite/tools/converter/quantizer/weight_quantizer.cc create mode 100644 mindspore/lite/tools/converter/quantizer/weight_quantizer.h diff --git a/.gitignore b/.gitignore index 057169ec420..22ca82834cc 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,20 @@ mindspore/lib output *.ir +# flatbuffer +mindspore/lite/tools/converter/parser/tflite/schema_generated.h +mindspore/lite/tools/converter/parser/caffe/caffe.pb.cc +mindspore/lite/tools/converter/parser/caffe/caffe.pb.h +mindspore/lite/tools/converter/parser/onnx/onnx.pb.h +mindspore/lite/tools/converter/parser/onnx/onnx.pb.h +mindspore/lite/tools/converter/schema/*.h +mindspore/lite/tools/converter/schema/inner +mindspore/lite/schema/*.h +mindspore/lite/schema/inner + +mindspore/lite/src/runtime/kernel/opencl/cl/fp16/*.inc +mindspore/lite/src/runtime/kernel/opencl/cl/fp32/*.inc + # Cmake files CMakeFiles/ cmake_install.cmake @@ -71,5 +85,3 @@ test_temp_summary_event_file/ mindspore/version.py mindspore/default_config.py mindspore/.commit_id -onnx.proto -mindspore/ccsrc/onnx.proto diff --git a/.gitmodules b/.gitmodules index c553d137c6e..80eac2de7dc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "third_party/flatbuffers"] path = third_party/flatbuffers url = https://github.com/google/flatbuffers.git + ignore = all [submodule "third_party/googletest"] path = third_party/googletest url = https://github.com/google/googletest.git @@ -10,9 +11,16 @@ [submodule "third_party/protobuf"] path = third_party/protobuf url = https://github.com/protocolbuffers/protobuf.git + ignore = all [submodule "akg"] path = akg url = https://gitee.com/mindspore/akg.git [submodule "graphengine"] path = graphengine url = https://gitee.com/mindspore/graphengine.git +[submodule "third_party/OpenCL-CLHPP"] + path = third_party/OpenCL-CLHPP + url = https://github.com/KhronosGroup/OpenCL-CLHPP.git +[submodule "third_party/OpenCL-Headers"] + path = third_party/OpenCL-Headers + url = https://github.com/KhronosGroup/OpenCL-Headers.git diff --git a/build.sh b/build.sh index 146b0de1c51..4ebcad069c8 100755 --- a/build.sh +++ b/build.sh @@ -25,7 +25,8 @@ usage() echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" - echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-w on|off] [-E] [-l on|off]" + echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\" + echo " [-B on|off] [-w on|off] [-E] [-l on|off]" echo "" echo "Options:" echo " -d Debug mode" @@ -51,7 +52,7 @@ usage() echo " -z Compile dataset & mindrecord, default on" echo " -M Enable MPI and NCCL for GPU training, gpu default on" echo " -V Specify the minimum required cuda version, default CUDA 10.1" - echo " -I Compile predict, default off" + echo " -I Compile lite" echo " -K Compile with AKG, default on" echo " -s Enable serving module, default off" echo " -w Enable acl module, default off" @@ -93,9 +94,10 @@ checkopts() COMPILE_MINDDATA="on" ENABLE_MPI="off" CUDA_VERSION="10.1" - COMPILE_PREDICT="off" + COMPILE_LITE="off" + LITE_PLATFORM="" + SUPPORT_TRAIN="off" USE_GLOG="on" - PREDICT_PLATFORM="" ENABLE_AKG="on" ENABLE_SERVING="off" ENABLE_ACL="off" @@ -240,13 +242,16 @@ checkopts() fi ;; I) - COMPILE_PREDICT="on" + COMPILE_LITE="on" if [[ "$OPTARG" == "arm64" ]]; then - PREDICT_PLATFORM="arm64" + LITE_PLATFORM="arm64" + elif [[ "$OPTARG" == "arm32" ]]; then + LITE_PLATFORM="arm32" elif [[ "$OPTARG" == "x86_64" ]]; then - PREDICT_PLATFORM="x86_64" + ENABLE_CONVERTER="on" + LITE_PLATFORM="x86_64" else - echo "-I parameter must be arm64 or x86_64" + echo "-I parameter must be arm64、arm32 or x86_64" exit 1 fi ;; @@ -382,128 +387,247 @@ build_mindspore() echo "success to build mindspore project!" } -build_predict() -{ - git submodule update --init --recursive third_party/incubator-tvm - echo "start build predict project" - - git submodule update --init --recursive third_party/flatbuffers - git submodule update --init --recursive third_party/googletest - git submodule update --init --recursive third_party/protobuf - - rm -rf "${BASEPATH}/predict/build" - mkdir -pv "${BASEPATH}/predict/build" - rm -rf "${BASEPATH}/predict/output" - mkdir -pv "${BASEPATH}/predict/output" - - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - if [ "${ANDROID_NDK}" ]; then - echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m" - else - echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/ \e[0m" - exit 1 - fi - fi - - #build flatbuf - cd "${BASEPATH}/third_party/flatbuffers" - rm -rf build && mkdir -p build && cd build && cmake .. && make -j$THREAD_NUM - FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc - cd "${BASEPATH}"/predict/schema && mkdir -p "${BASEPATH}"/predict/schema/inner - find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b - find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o ${BASEPATH}/predict/schema/inner - - # check LLVM_PATH - if [ "${LLVM_PATH}" == "" ]; then - echo "Please set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config" - exit - fi - - #build tvm - tvm_open_source="${BASEPATH}/third_party/incubator-tvm" - tvm_kernel_build="${BASEPATH}/predict/module/tvm_kernel" - if [ ! -f "${tvm_kernel_build}"/incubator-tvm/build/libtvm.so ]; then - rm -fr "${tvm_kernel_build}"/incubator-tvm - cp -fr "${tvm_open_source}" "${tvm_kernel_build}" - mkdir -p "${tvm_kernel_build}"/incubator-tvm/build - patch -d "${tvm_kernel_build}"/incubator-tvm -p1 < "${BASEPATH}"/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch - cp "${tvm_kernel_build}"/lite/src/codegen/llvm/lite_rtfunc_reset.cc "${tvm_kernel_build}"/incubator-tvm/src/codegen/llvm/ - cp "${tvm_open_source}"/cmake/config.cmake "${tvm_kernel_build}"/incubator-tvm - if [ "${LLVM_PATH}" ]; then - sed -i "s#set(USE_LLVM .*)#set(USE_LLVM \"${LLVM_PATH}\")#g" "${tvm_kernel_build}"/incubator-tvm/config.cmake - else - echo "need set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config" - fi - cd "${tvm_kernel_build}"/incubator-tvm/build - cmake .. - make -j$THREAD_NUM +checkndk() { + if [ "${ANDROID_NDK}" ]; then + echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m" else - cd "${tvm_kernel_build}"/incubator-tvm/build - make -j$THREAD_NUM + echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m" + exit 1 fi - - #gen op - predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_x86" - predict_platform="x86" - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_arm64" - predict_platform="arm64" - fi - - need_get_libs=true - if [ -d "${predict_tvm_op_lib_path}" ]; then - file_list=$(ls "${predict_tvm_op_lib_path}") - if [ -n "${file_list}" ]; then - libstime=$(stat -c %Y "${predict_tvm_op_lib_path}"/* | sort -u | tail -n1) - pythontime=$(find "${BASEPATH}"/predict/module/tvm_kernel/lite/python/ -name "*.py" -exec stat -c %Y {} \; | - sort -u | tail -n1) - if [ "${libstime}" -ge "${pythontime}" ]; then - need_get_libs=false - else - rm -fr "${predict_tvm_op_lib_path}" - fi - fi - fi - - if $need_get_libs; then - PYTHONPATH_OLD=${PYTHONPATH} - export PYTHONPATH="${tvm_kernel_build}/incubator-tvm/python:${tvm_kernel_build}/incubator-tvm/topi/python:${tvm_kernel_build}/incubator-tvm/nnvm/python:${tvm_kernel_build}/lite/python:" - cd "${BASEPATH}"/predict/module/tvm_kernel/lite/python/at_ops - python3 at_gen_strip.py ${predict_platform} - export PYTHONPATH=${PYTHONPATH_OLD} - fi - - cd "${BASEPATH}/predict/build" - if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ - -DANDROID_NATIVE_API_LEVEL=android-19 -DANDROID_NDK="${ANDROID_NDK}" \ - -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" -DANDROID_STL="c++_shared" \ - -DANDROID_ABI="arm64-v8a" -DENABLE_PREDICT_ARM64=ON -DANDROID_ALLOW_UNDEFINED_SYMBOLS=TRUE .. - elif [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - cmake .. - fi - - make ${VERBOSE} -j$THREAD_NUM - if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - cd "${BASEPATH}/predict/build/test" && ./run_tests.sh - fi - - # copy securec include files - mkdir -p "${BASEPATH}/predict/output/include/securec/include" - cp "${BASEPATH}"/third_party/securec/include/* "${BASEPATH}"/predict/output/include/securec/include - - cd "${BASEPATH}/predict/output/" - if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then - tar -cf MSPredict-0.5.0-linux_x86_64.tar.gz include/ lib/ --warning=no-file-changed - elif [[ "$PREDICT_PLATFORM" == "arm64" ]]; then - tar -cf MSPredict-0.5.0-linux_aarch64.tar.gz include/ lib/ --warning=no-file-changed - fi - echo "success to build predict project!" } -if [[ "X$COMPILE_PREDICT" = "Xon" ]]; then - build_predict - echo "---------------- mindspore: build end ----------------" +gene_flatbuffer() { + FLAT_DIR="${BASEPATH}/mindspore/lite/schema" + cd ${FLAT_DIR} && rm -rf "${FLAT_DIR}/inner" && mkdir -p "${FLAT_DIR}/inner" + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/inner" + + FLAT_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/tflite" + cd ${FLAT_DIR} + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/" +} + +build_flatbuffer() { + cd ${BASEPATH} + FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc + if [[ ! -f "${FLATC}" ]]; then + git submodule update --init --recursive third_party/flatbuffers + cd ${BASEPATH}/third_party/flatbuffers + rm -rf build && mkdir -pv build && cd build && cmake .. && make -j$THREAD_NUM + gene_flatbuffer + fi + if [[ "${INC_BUILD}" == "off" ]]; then + gene_flatbuffer + fi +} + +gene_protobuf() { + PROTO_SRC_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/caffe" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" + PROTO_SRC_DIR="${BASEPATH}/mindspore/lite/tools/converter/parser/onnx" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" +} + +build_protobuf() { + cd ${BASEPATH} + PROTOC="${BASEPATH}"/third_party/protobuf/build/bin/protoc + if [[ ! -f "${PROTOC}" ]]; then + git submodule update --init --recursive third_party/protobuf + cd ${BASEPATH}/third_party/protobuf + rm -rf build && mkdir -pv build && ./autogen.sh + ./configure --prefix=${BASEPATH}/third_party/protobuf/build + make clean && make -j$THREAD_NUM && make install + gene_protobuf + fi + if [[ "${INC_BUILD}" == "off" ]]; then + gene_protobuf + fi +} + +build_gtest() { + cd ${BASEPATH} + git submodule update --init --recursive third_party/googletest +} + +gene_clhpp() { + CL_SRC_DIR="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl" + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue + fi + cd ${CL_SRC_DIR}/${data_type} + rm -rf *.inc + echo "$(cd "$(dirname $0)"; pwd)" + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + inc_file=`echo ${CL_SRC_DIR}/${data_type}/${file} | sed 's/$/.inc/'` + sed 's/^/\"/;s/$/ \\n\" \\/' ${CL_SRC_DIR}/${data_type}/${file} > ${inc_file} + kernel_name=`echo ${file} | sed s'/.\{3\}$//'` + sed -i "1i\static const char *${kernel_name}_source_${data_type} =\"\\n\" \\" ${inc_file} + sed -i '$a\;' ${inc_file} + done + done +} + +gene_ocl_program() { + CL_SRC_DIR="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl" + SPIRV_DIR=build/spirv + rm -rf ${SPIRV_DIR} + mkdir -pv ${SPIRV_DIR} + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue + fi + #echo $(cd "$(dirname $0)"; pwd) + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + if [ "${file##*.}" != "cl" ]; then + continue + fi + clang -Xclang -finclude-default-header -cl-std=CL2.0 --target=spir64-unknown-unknown -emit-llvm \ + -c -O0 -o ${SPIRV_DIR}/${file%.*}.bc ${CL_SRC_DIR}/${data_type}/${file} + done + done + + bcs=`ls ${SPIRV_DIR}/*.bc` + llvm-link ${bcs} -o ${SPIRV_DIR}/program.bc + llvm-spirv -o ${SPIRV_DIR}/program.spv ${SPIRV_DIR}/program.bc + + CL_PROGRAM_PATH="${BASEPATH}/mindspore/lite/src/runtime/kernel/opencl/cl/program.inc" + echo "#include " > ${CL_PROGRAM_PATH} + echo "std::vector g_program_binary = {" >> ${CL_PROGRAM_PATH} + #hexdump -v -e '16/1 "0x%02x, " "\n"' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + hexdump -v -e '1/1 "0x%02x, "' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + echo "};" >> ${CL_PROGRAM_PATH} + echo "Compile SPIRV done" +} + +build_opencl() { + cd ${BASEPATH} + git submodule update --init third_party/OpenCL-Headers + git submodule update --init third_party/OpenCL-CLHPP + if [[ "${OPENCL_OFFLINE_COMPILE}" == "on" ]]; then + gene_ocl_program + else + gene_clhpp + fi +} + +build_lite() +{ + echo "start build mindspore lite project" + + if [[ "${ENABLE_GPU}" == "on" ]]; then + build_opencl + fi + if [[ "${LITE_PLATFORM}" == "x86_64" ]]; then + build_protobuf + fi + build_flatbuffer + build_gtest + + cd "${BASEPATH}/mindspore/lite" + mkdir -pv build + cd build + BUILD_TYPE="Release" + if [[ "${DEBUG_MODE}" == "on" ]]; then + BUILD_TYPE="Debug" + fi + + if [[ "${LITE_PLATFORM}" == "arm64" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM64=on -DBUILD_CONVERTER=off -DENABLE_NEON=on -DENABLE_FP16="off" \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} "${BASEPATH}/mindspore/lite" + elif [[ "${LITE_PLATFORM}" == "arm32" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DBUILD_CONVERTER=off \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} "${BASEPATH}/mindspore/lite" + else + cmake -DBUILD_DEVICE=on -DPLATFORM_ARM64=off -DBUILD_CONVERTER=${ENABLE_CONVERTER} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} "${BASEPATH}/mindspore/lite" + fi + VERBOSE=2 make -j$THREAD_NUM + COMPILE_RET=$? + + if [[ "${COMPILE_RET}" -ne 0 ]]; then + echo "---------------- mindspore lite: build failed ----------------" + else + mkdir -pv ${BASEPATH}/mindspore/lite/output/ + if [[ "$LITE_PLATFORM" == "x86_64" ]]; then + OUTPUT_DIR=${BASEPATH}/mindspore/lite/output/MSLite-0.5.0-linux_x86_64 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/converter && mkdir -p ${OUTPUT_DIR}/time_profile + mkdir -p ${OUTPUT_DIR}/benchmark && mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/converter/converter_lite ${OUTPUT_DIR}/converter/ + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/protobuf/lib + cp -r ${BASEPATH}/third_party/protobuf/build/include/ ${OUTPUT_DIR}/third_party/protobuf/ + cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19 ${OUTPUT_DIR}/third_party/protobuf/lib/ + cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 ${OUTPUT_DIR}/third_party/protobuf/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -cf MSLite-0.5.0-linux_x86_64.tar.gz MSLite-0.5.0-linux_x86_64/ --warning=no-file-changed + elif [[ "$LITE_PLATFORM" == "arm64" ]]; then + OUTPUT_DIR=${BASEPATH}/mindspore/lite/output/MSLite-0.5.0-linux_arm64 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark + mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -cf MSLite-0.5.0-linux_arm64.tar.gz MSLite-0.5.0-linux_arm64/ --warning=no-file-changed + elif [[ "$LITE_PLATFORM" == "arm32" ]]; then + OUTPUT_DIR=${BASEPATH}/mindspore/lite/output/MSLite-0.5.0-linux_arm32 + rm -rf ${OUTPUT_DIR} && mkdir -p ${OUTPUT_DIR} && cd ${OUTPUT_DIR} + mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark + mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib + mkdir -p ${OUTPUT_DIR}/third_party + cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/ + cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/ + mkdir -p ${OUTPUT_DIR}/include/ir/dtype/ + cp ${BASEPATH}/mindspore/core/ir/dtype/type_id.h ${OUTPUT_DIR}/include/ir/dtype/ + mkdir -p ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/schema/*.h ${OUTPUT_DIR}/include/schema/ + cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/ + mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers + cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/ + cd .. + tar -cf MSLite-0.5.0-linux_arm32.tar.gz MSLite-0.5.0-linux_arm32/ --warning=no-file-changed + fi + echo "---------------- mindspore lite: build success ----------------" + fi +} + +if [[ "X$COMPILE_LITE" = "Xon" ]]; then + build_lite exit else build_mindspore diff --git a/mindspore/core/ir/CMakeLists.txt b/mindspore/core/ir/CMakeLists.txt index 2a0b81ae047..77bc1b7661a 100644 --- a/mindspore/core/ir/CMakeLists.txt +++ b/mindspore/core/ir/CMakeLists.txt @@ -1,7 +1,3 @@ file(GLOB_RECURSE _IR_SRC_LIST ./*.cc dtype/*.cc) -file(GLOB_RECURSE _IR_LITE_SRC_FILES - ./lite/tensor.cc - ) -list(REMOVE_ITEM _IR_SRC_LIST ${_IR_LITE_SRC_FILES}) set_property(SOURCE ${_IR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_IR) add_library(_mindspore_ir_obj OBJECT ${_IR_SRC_LIST}) diff --git a/mindspore/core/ir/lite/tensor.cc b/mindspore/core/ir/lite/tensor.cc deleted file mode 100644 index 9c3921eaddd..00000000000 --- a/mindspore/core/ir/lite/tensor.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "ir/lite/tensor.h" -#include "securec/include/securec.h" - -namespace mindspore { -namespace tensor { -#define kMaxMallocSize 1024 * 1024 * 100 -Tensor::Tensor(const TypeId data_type, const std::vector &shape) : MetaTensor(data_type, shape) {} - -Tensor::Tensor(const TypePtr &type_ptr, const std::vector &shape) : MetaTensor(type_ptr, shape) {} - -Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) { - this->data_type_ = tensor.data_type_; - this->shape_ = tensor.shape_; - auto ret = CopyTensorData(tensor); - if (0 != ret) { - MS_LOG(EXCEPTION) << "CopyTensorData error"; - } -} - -int Tensor::CopyTensorData(const Tensor &srcTensor) { - if (srcTensor.data_ == nullptr) { - MS_LOG(ERROR) << "data of srcTensor is nullptr"; - return -1; - } - size_t data_size = this->Size(); - MS_ASSERT(data_size == tensor.Size()); - if (this->data_ == nullptr) { - if (data_size > kMaxMallocSize) { - MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes"; - return -1; - } - this->data_ = malloc(data_size); - } - memcpy_s(this->data_, data_size, tensor.data_, tensor.Size()); - return 0; -} - -Tensor::~Tensor() { - if (nullptr != this->data_) { - free(this->data_); - } -} - -Tensor &Tensor::operator=(const Tensor &tensor) { - if (&tensor == this) { - return *this; - } - this->shape_ = tensor.shape_; - this->data_type_ = tensor.data_type_; - auto ret = CopyTensorData(tensor); - if (0 != ret) { - MS_LOG(EXCEPTION) << "CopyTensorData error"; - } - return *this; -} - -bool Tensor::operator==(const Tensor &tensor) { - return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_; -} - -bool Tensor::operator==(const Value &other) const { - if (other.isa()) { - auto other_ = static_cast(other); - return *this == other_; - } else { - return false; - } -} -} // namespace tensor -} // namespace mindspore diff --git a/mindspore/core/ir/lite/tensor.h b/mindspore/core/ir/lite/tensor.h deleted file mode 100644 index 7644f7d84fd..00000000000 --- a/mindspore/core/ir/lite/tensor.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CORE_IR_LITE_TENSOR_H_ -#define MINDSPORE_CORE_IR_LITE_TENSOR_H_ - -#include -#include -#include "ir/meta_tensor.h" -#include "ir/dtype/type.h" - -namespace mindspore { -namespace tensor { -class Tensor : public MetaTensor { - public: - Tensor() : MetaTensor() {} - - Tensor(const TypeId data_type, const std::vector &shape); - - Tensor(const TypePtr &type_ptr, const std::vector &shape); - - Tensor(const Tensor &tensor); - - ~Tensor(); - - int CopyTensorData(const Tensor &srcTensor); - - MS_DECLARE_PARENT(Tensor, MetaTensor) - - virtual Tensor &operator=(const Tensor &tensor); - - virtual bool operator==(const Tensor &tensor); - - bool operator==(const Value &other) const override; - - size_t Size() const { return MetaTensor::ElementsNum() * GetTypeByte(TypeIdToType(this->data_type_)); } - - void *Data() const { return data_; } - - protected: - void *data_; -}; - -using TensorPtr = std::shared_ptr; -} // namespace tensor -} // namespace mindspore - -#endif // MINDSPORE_CORE_IR_LITE_TENSOR_H_ diff --git a/mindspore/core/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc index c0b6b79a64e..41b069b770e 100644 --- a/mindspore/core/ir/meta_tensor.cc +++ b/mindspore/core/ir/meta_tensor.cc @@ -75,8 +75,6 @@ int MetaTensor::ElementsNum() const { return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies()); } -TypePtr MetaTensor::Dtype() const { return TypeIdToType(data_type_); } - TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { if (type_ptr == nullptr) { MS_LOG(ERROR) << "Dtype to be set is nullptr."; diff --git a/mindspore/core/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc index d73aa193742..53fc58eb784 100644 --- a/mindspore/core/ir/meta_tensor_extends.cc +++ b/mindspore/core/ir/meta_tensor_extends.cc @@ -37,5 +37,7 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() { abs_tensor->set_value(shared_from_base()); return abs_tensor; } + +TypePtr MetaTensor::Dtype() const { return TypeIdToType(data_type_); } } // namespace tensor } // namespace mindspore diff --git a/mindspore/core/ir/param_value.h b/mindspore/core/ir/param_value.h index 89730b02a42..d9976ccde22 100644 --- a/mindspore/core/ir/param_value.h +++ b/mindspore/core/ir/param_value.h @@ -31,7 +31,7 @@ class ParamValue { ParamValue(const ParamValue &other) = default; - ~ParamValue() = default; + virtual ~ParamValue() = default; tensor::MetaTensorPtr value() const { return value_; } void set_value(const tensor::MetaTensorPtr &value) { value_ = value; } diff --git a/mindspore/core/utils/log_adapter.cc b/mindspore/core/utils/log_adapter.cc index 175e790c35d..3a3f0e296ce 100644 --- a/mindspore/core/utils/log_adapter.cc +++ b/mindspore/core/utils/log_adapter.cc @@ -17,11 +17,15 @@ #include "utils/log_adapter.h" #include +#include #include +#ifndef USE_ANDROID_LOG #include "debug/trace.h" +#endif // namespace to support utils module definition namespace mindspore { +#ifndef USE_ANDROID_LOG #ifdef USE_GLOG static std::string GetTime() { #define BUFLEN 80 @@ -125,6 +129,7 @@ static int GetSlogLevel(MsLogLevel level) { } } #endif +#endif static std::string ExceptionTypeToString(ExceptionType type) { #define _TO_STRING(x) #x @@ -184,7 +189,24 @@ static const char *GetSubModuleName(SubModuleId module_id) { return sub_module_names[module_id % NUM_SUBMODUES]; } +const char *EnumStrForMsLogLevel(MsLogLevel level) { + if (level == DEBUG) { + return "DEBUG"; + } else if (level == INFO) { + return "INFO"; + } else if (level == WARNING) { + return "WARNING"; + } else if (level == ERROR) { + return "ERROR"; + } else if (level == EXCEPTION) { + return "EXCEPTION"; + } else { + return "NO_LEVEL"; + } +} + void LogWriter::OutputLog(const std::ostringstream &msg) const { +#ifndef USE_ANDROID_LOG #ifdef USE_GLOG auto submodule_name = GetSubModuleName(submodule_); google::LogMessage("", 0, GetGlogLevel(log_level_)).stream() @@ -197,6 +219,10 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const { Dlog(static_cast(slog_module_id), GetSlogLevel(log_level_), "[%s:%d] %s] %s", location_.file_, location_.line_, location_.func_, str_msg.c_str()); #endif +#else + printf("%s [%s:%d] %s] %s\n:", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_, + msg.str().c_str()); +#endif } void LogWriter::operator<(const LogStream &stream) const noexcept { @@ -218,8 +244,10 @@ void LogWriter::operator^(const LogStream &stream) const { } oss << msg.str(); +#ifndef USE_ANDROID_LOG trace::TraceGraphEval(); trace::GetEvalStackInfo(oss); +#endif if (exception_handler_ != nullptr) { exception_handler_(exception_type_, oss.str()); diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 866b61b864c..96906fbcdab 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -25,11 +25,13 @@ #include #include "utils/overload.h" #include "./securec.h" +#ifndef USE_ANDROID_LOG #ifdef USE_GLOG #include "glog/logging.h" #else #include "toolchain/slog.h" #endif +#endif // NOTICE: when relative path of 'log_adapter.h' changed, macro 'LOG_HDR_FILE_REL_PATH' must be changed #define LOG_HDR_FILE_REL_PATH "mindspore/core/utils/log_adapter.h" @@ -129,6 +131,8 @@ enum SubModuleId : int { #define SUBMODULE_ID mindspore::SubModuleId::SM_ME #endif +const char *EnumStrForMsLogLevel(MsLogLevel level); + #if defined(_WIN32) || defined(_WIN64) extern int g_ms_submodule_log_levels[] __attribute__((dllexport)); #else diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt new file mode 100644 index 00000000000..fd70ace0fe1 --- /dev/null +++ b/mindspore/lite/CMakeLists.txt @@ -0,0 +1,119 @@ +cmake_minimum_required(VERSION 3.14) +project (Lite) + +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) + message(FATAL_ERROR "GCC vesion ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") +endif () + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") +set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../..) +set(CORE_DIR ${TOP_DIR}/mindspore/core) +set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc) +include_directories(${TOP_DIR}) +include_directories(${CORE_DIR}) +include_directories(${CCSRC_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${TOP_DIR}/third_party) +include_directories(${TOP_DIR}/third_party/flatbuffers/include) + +include(${TOP_DIR}/cmake/utils.cmake) +include(${TOP_DIR}/cmake/external_libs/json.cmake) +include(${TOP_DIR}/cmake/dependency_securec.cmake) +set(CMAKE_VERBOSE_MAKEFILE on) +add_compile_definitions(USE_ANDROID_LOG) +add_compile_definitions(NO_DLIB) +add_compile_options(-fPIC) + +option(BUILD_DEVICE "if build device" on) +option(SUPPORT_TRAIN "if build for on-device train" off) +option(PLATFORM_ARM64 "if build device for arm64" off) +option(PLATFORM_ARM32 "if build device for arm32" off) +option(BUILD_CONVERTER "if build converter" on) +option(ENABLE_FP16 "if build fp16 ops" off) +option(SUPPORT_GPU "if support gpu" off) +option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off) + +if (BUILD_DEVICE) + add_compile_definitions(BUILD_DEVICE) +endif() +if (SUPPORT_TRAIN) + add_compile_definitions(SUPPORT_TRAIN) +endif() +if (ENABLE_NEON) + add_compile_definitions(ENABLE_NEON) +endif () +if (ENABLE_FP16) + add_compile_definitions(ENABLE_FP16) +endif () +if (SUPPORT_GPU) + add_definitions(-DUSE_OPENCL_WRAPPER) + add_definitions(-DMS_OPENCL_PROFILE=false) + add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200) + add_compile_definitions(SUPPORT_GPU) + if(OFFLINE_COMPILE) + add_compile_definitions(PROGRAM_WITH_IL) + endif() + include_directories(${TOP_DIR}/third_party/OpenCL-Headers) + include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) +endif() + +set(ANF_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../core/ir/meta_tensor.cc + ${CCSRC_DIR}/gvar/logging_level.cc + ${CCSRC_DIR}/gvar/typeid_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../core/base/base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../core/utils/log_adapter.cc + ) +if (BUILD_CONVERTER) + if (PLATFORM_ARM64 OR PLATFORM_ARM32) + MESSAGE(FATAL_ERROR "Cannot build converter in arm platform") + endif() + find_package(Python3 3.7 COMPONENTS Interpreter Development) + if(Python3_FOUND) + set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}") + set(PYTHON_LIBRARIES "${Python3_LIBRARIES}") + if (WIN32) + if (Python3_DIR) + message("Python3_DIR set already: " ${Python3_DIR}) + else() + string(LENGTH ${PYTHON_LIBRARIES} PYTHON_LIBRARIES_LEN) + string(LENGTH "libpythonxx.a" Python3_NAME_LEN) + math(EXPR Python3_DIR_LEN ${PYTHON_LIBRARIES_LEN}-${Python3_NAME_LEN}) + string(SUBSTRING ${Python3_LIBRARIES} 0 ${Python3_DIR_LEN} Python3_DIR) + message("Python3_DIR: " ${Python3_DIR}) + endif() + link_directories(${Python3_DIR}) + endif() + else() + find_python_package(py_inc py_lib) + set(PYTHON_INCLUDE_DIRS "${py_inc}") + set(PYTHON_LIBRARIES "${py_lib}") + endif() + include_directories(${PYTHON_INCLUDE_DIRS}) + include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) + include(${TOP_DIR}/cmake/external_libs/eigen.cmake) + include_directories(${TOP_DIR}/third_party/protobuf/build/include) + link_directories(${TOP_DIR}/third_party/protobuf/build/lib) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) + add_subdirectory(src/common/anf_exporter) +endif() + +if (BUILD_DEVICE) + if (PLATFORM_ARM32 OR PLATFORM_ARM64) + if (NOT DEFINED ENV{ANDROID_NDK}) + message(FATAL_ERROR "env ANDROID_NDK should be setted for ARM compile") + endif() + add_compile_definitions(ENABLE_ARM) + endif() + if (PLATFORM_ARM32) + add_definitions(-mfloat-abi=softfp -mfpu=neon) + add_compile_definitions(ENABLE_ARM32) + endif() + if (PLATFORM_ARM64) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + add_compile_definitions(ENABLE_ARM64) + endif() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark) +# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/test) +endif() diff --git a/mindspore/lite/build.sh b/mindspore/lite/build.sh new file mode 100755 index 00000000000..15243f98c98 --- /dev/null +++ b/mindspore/lite/build.sh @@ -0,0 +1,272 @@ +#!/usr/bin/env bash + +set -e + +CUR_DIR=$(cd "$(dirname $0)"; pwd) +BASE_DIR=${CUR_DIR}/../../ + +usage() +{ + echo "Usage:" + echo "bash build.sh [-d] [-a arm64|arm32] [-j[n]] [-m] [-f] [-g] [-c] [-s] [-o]" + echo "" + echo "Options:" + echo " -d Enable Debug" + echo " -c Enable compile converter, default off" + echo " -m Enable Incremental compilation" + echo " -a Select ARM platform, default off" + echo " -j[n] Set the threads when building, default: -j8" + echo " -f Compile fp16 ops" + echo " -g Enable gpu compile" + echo " -s Support train" + echo " -o Offline compile OpenCL kernel" +} + +checkopts() +{ + # Init default values of build options + THREAD_NUM="8" + BUILD_TYPE="Release" + BUILD_DEVICE_PLATFORM="off" + MAKE_ONLY="off" + ENABLE_FP16="off" + ENABLE_GPU="off" + ENABLE_CONVERTER="off" + SUPPORT_TRAIN="off" + OFFLINE_COMPILE="off" + + # Process the options + while getopts 'j:da:mfcsgo' opt + do + OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') + case "${opt}" in + m) + MAKE_ONLY="on" + echo "Incremental compilation" + ;; + d) + BUILD_TYPE="Debug" + echo "Build Debug version" + ;; + j) + THREAD_NUM=$OPTARG + ;; + a) + if [[ "X$OPTARG" == "Xarm64" ]]; then + BUILD_DEVICE_PLATFORM="arm64" + echo "Enable arm64" + elif [[ "X$OPTARG" == "Xarm32" ]]; then + BUILD_DEVICE_PLATFORM="arm32" + echo "Enable arm32" + else + echo "-I parameter must be arm64 or arm32" + exit 1 + fi + ;; + c) + ENABLE_CONVERTER="on" + echo "Enable converter" + ;; + s) + SUPPORT_TRAIN="on" + echo "Support train" + ;; + f) + ENABLE_FP16="on" + echo "Enable fp16" + ;; + g) + ENABLE_GPU="on" + echo "Enable gpu" + ;; + o) + OFFLINE_COMPILE="on" + echo "OpenCL kernel offline compile" + ;; + *) + echo "Unknown option ${opt}!" + usage + exit 1 + esac + done +} + +checkndk() { + if [ "${ANDROID_NDK}" ]; then + echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m" + else + echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/ \e[0m" + exit 1 + fi +} + +gene_flatbuffer() { + FLAT_DIR="${BASE_DIR}/mindspore/lite/schema" + cd ${FLAT_DIR} && rm -rf "${FLAT_DIR}/inner" && mkdir -p "${FLAT_DIR}/inner" + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/inner" + + FLAT_DIR="${BASE_DIR}/mindspore/lite/tools/converter/parser/tflite" + cd ${FLAT_DIR} + find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o "${FLAT_DIR}/" +} + +build_flatbuffer() { + cd ${BASE_DIR} + FLATC="${BASE_DIR}"/third_party/flatbuffers/build/flatc + if [[ ! -f "${FLATC}" ]]; then + git submodule update --init --recursive third_party/flatbuffers + cd ${BASE_DIR}/third_party/flatbuffers + rm -rf build && mkdir -pv build && cd build && cmake .. && make -j$THREAD_NUM + gene_flatbuffer + fi + if [[ "${MAKE_ONLY}" == "off" ]]; then + gene_flatbuffer + fi +} + +gene_protobuf() { + PROTO_SRC_DIR="${BASE_DIR}/mindspore/lite/tools/converter/parser/caffe" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" + PROTO_SRC_DIR="${BASE_DIR}/mindspore/lite/tools/converter/parser/onnx" + find ${PROTO_SRC_DIR} -name "*.proto" -print0 | xargs -0 "${PROTOC}" -I"${PROTO_SRC_DIR}" --cpp_out="${PROTO_SRC_DIR}" +} + +build_protobuf() { + cd ${BASE_DIR} + PROTOC="${BASE_DIR}"/third_party/protobuf/build/bin/protoc + if [[ ! -f "${PROTOC}" ]]; then + git submodule update --init --recursive third_party/protobuf + cd ${BASE_DIR}/third_party/protobuf + rm -rf build && mkdir -pv build && ./autogen.sh + ./configure --prefix=${BASE_DIR}/third_party/protobuf/build + make clean && make -j$THREAD_NUM && make install + gene_protobuf + fi + if [[ "${MAKE_ONLY}" == "off" ]]; then + gene_protobuf + fi +} + +build_gtest() { + cd ${BASE_DIR} + git submodule update --init --recursive third_party/googletest +} + +gene_clhpp() { + CL_SRC_DIR="${BASE_DIR}/mindspore/lite/src/runtime/kernel/opencl/cl" + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue + fi + cd ${CL_SRC_DIR}/${data_type} + rm -rf *.inc + echo "$(cd "$(dirname $0)"; pwd)" + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + inc_file=`echo ${CL_SRC_DIR}/${data_type}/${file} | sed 's/$/.inc/'` + sed 's/^/\"/;s/$/ \\n\" \\/' ${CL_SRC_DIR}/${data_type}/${file} > ${inc_file} + kernel_name=`echo ${file} | sed s'/.\{3\}$//'` + sed -i "1i\static const char *${kernel_name}_source_${data_type} =\"\\n\" \\" ${inc_file} + sed -i '$a\;' ${inc_file} + done + done +} + +gene_ocl_program() { + CL_SRC_DIR="${BASE_DIR}/mindspore/lite/src/runtime/kernel/opencl/cl" + SPIRV_DIR=build/spirv + rm -rf ${SPIRV_DIR} + mkdir -pv ${SPIRV_DIR} + for sub_dir in "${CL_SRC_DIR}"/* + do + data_type="$(basename ${sub_dir})" + if [ ! -d ${CL_SRC_DIR}/${data_type} ]; then + continue + fi + #echo $(cd "$(dirname $0)"; pwd) + for file_path in "${CL_SRC_DIR}/${data_type}"/* + do + file="$(basename ${file_path})" + if [ "${file##*.}" != "cl" ]; then + continue + fi + clang -Xclang -finclude-default-header -cl-std=CL2.0 --target=spir64-unknown-unknown -emit-llvm \ + -c -O0 -o ${SPIRV_DIR}/${file%.*}.bc ${CL_SRC_DIR}/${data_type}/${file} + done + done + + bcs=`ls ${SPIRV_DIR}/*.bc` + llvm-link ${bcs} -o ${SPIRV_DIR}/program.bc + llvm-spirv -o ${SPIRV_DIR}/program.spv ${SPIRV_DIR}/program.bc + + CL_PROGRAM_PATH="${BASE_DIR}/mindspore/lite/src/runtime/kernel/opencl/cl/program.inc" + echo "#include " > ${CL_PROGRAM_PATH} + echo "std::vector g_program_binary = {" >> ${CL_PROGRAM_PATH} + #hexdump -v -e '16/1 "0x%02x, " "\n"' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + hexdump -v -e '1/1 "0x%02x, "' ${SPIRV_DIR}/program.spv >> ${CL_PROGRAM_PATH} + echo "};" >> ${CL_PROGRAM_PATH} + echo "Compile SPIRV done" +} + +build_opencl() { + cd ${BASE_DIR} + git submodule update --init third_party/OpenCL-Headers + git submodule update --init third_party/OpenCL-CLHPP + if [[ "${OFFLINE_COMPILE}" == "on" ]]; then + gene_ocl_program + else + gene_clhpp + fi +} + +buildlite() { + if [[ "${MAKE_ONLY}" == "off" ]]; then + cd ${CUR_DIR} + rm -rf build + mkdir -pv build + cd build + if [[ "${BUILD_DEVICE_PLATFORM}" == "arm64" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM64=on -DBUILD_CONVERTER=off -DENABLE_NEON=on -DENABLE_FP16="${ENABLE_FP16}" \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OFFLINE_COMPILE} .. + elif [[ "${BUILD_DEVICE_PLATFORM}" == "arm32" ]]; then + checkndk + cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ + -DANDROID_STL="c++_shared" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DBUILD_DEVICE=on -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DBUILD_CONVERTER=off \ + -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OFFLINE_COMPILE} .. + else + cmake -DBUILD_DEVICE=on -DPLATFORM_ARM64=off -DBUILD_CONVERTER=${ENABLE_CONVERTER} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DOFFLINE_COMPILE=${OFFLINE_COMPILE} .. + fi + else + cd ${CUR_DIR}/build + fi + VERBOSE=2 make -j$THREAD_NUM +} + +echo "---------------- mindspore lite: build start ----------------" +checkopts "$@" +build_flatbuffer +if [[ "${ENABLE_CONVERTER}" == "on" ]]; then + build_protobuf +fi +if [[ "${ENABLE_GPU}" == "on" ]]; then + build_opencl +fi +build_gtest +buildlite +COMPILE_RET=$? +if [[ "${COMPILE_RET}" -ne 0 ]]; then + echo "---------------- mindspore lite: build failed ----------------" +else + echo "---------------- mindspore lite: build success ----------------" +fi diff --git a/mindspore/lite/cmake-build-cloud/Lite.cbp b/mindspore/lite/cmake-build-cloud/Lite.cbp new file mode 100644 index 00000000000..c33b0880ed1 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/Lite.cbp @@ -0,0 +1,4235 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-cloud/googletest/CTestTestfile.cmake b/mindspore/lite/cmake-build-cloud/googletest/CTestTestfile.cmake new file mode 100644 index 00000000000..08108fdcf45 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/CTestTestfile.cmake @@ -0,0 +1,7 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-cloud/googletest +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. +subdirs("googlemock") diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/CTestTestfile.cmake b/mindspore/lite/cmake-build-cloud/googletest/googlemock/CTestTestfile.cmake new file mode 100644 index 00000000000..4b7b83c2412 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/CTestTestfile.cmake @@ -0,0 +1,7 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest/googlemock +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-cloud/googletest/googlemock +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. +subdirs("gtest") diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gmock.cbp b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gmock.cbp new file mode 100644 index 00000000000..d4b5f94e6fd --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gmock.cbp @@ -0,0 +1,517 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/CTestTestfile.cmake b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/CTestTestfile.cmake new file mode 100644 index 00000000000..fe5ea99621e --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/CTestTestfile.cmake @@ -0,0 +1,6 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest/googletest +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfig.cmake b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfig.cmake new file mode 100644 index 00000000000..0ee9ec8f9f5 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfig.cmake @@ -0,0 +1,33 @@ + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was Config.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### +include(CMakeFindDependencyMacro) +if (ON) + set(THREADS_PREFER_PTHREAD_FLAG ON) + find_dependency(Threads) +endif() + +include("${CMAKE_CURRENT_LIST_DIR}/GTestTargets.cmake") +check_required_components("") diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake new file mode 100644 index 00000000000..b12397d6586 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake @@ -0,0 +1,37 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. +# The variable CVF_VERSION must be set before calling configure_file(). + +set(PACKAGE_VERSION "1.9.0") + +if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() + + +# if the installed project requested no architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock.pc b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock.pc new file mode 100644 index 00000000000..d4242cfa66f --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gmock +Description: GoogleMock (without main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgmock -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock_main.pc b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock_main.pc new file mode 100644 index 00000000000..2da4fbcc017 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gmock_main.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gmock_main +Description: GoogleMock (with main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgmock_main -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest.pc b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest.pc new file mode 100644 index 00000000000..a9931b8be97 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gtest +Description: GoogleTest (without main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgtest -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest_main.pc b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest_main.pc new file mode 100644 index 00000000000..57948c76ee1 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/generated/gtest_main.pc @@ -0,0 +1,10 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gtest_main +Description: GoogleTest (with main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Requires: gtest +Libs: -L${libdir} -lgtest_main -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/gtest.cbp b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/gtest.cbp new file mode 100644 index 00000000000..df1d2d30c9d --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googlemock/gtest/gtest.cbp @@ -0,0 +1,327 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-cloud/googletest/googletest-distribution.cbp b/mindspore/lite/cmake-build-cloud/googletest/googletest-distribution.cbp new file mode 100644 index 00000000000..933a58b2371 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/googletest/googletest-distribution.cbp @@ -0,0 +1,517 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-cloud/src/runtime/kernel/arm/opclib/optimize.cbp b/mindspore/lite/cmake-build-cloud/src/runtime/kernel/arm/opclib/optimize.cbp new file mode 100644 index 00000000000..2e3dd81d819 --- /dev/null +++ b/mindspore/lite/cmake-build-cloud/src/runtime/kernel/arm/opclib/optimize.cbp @@ -0,0 +1,112 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-minnie/Lite.cbp b/mindspore/lite/cmake-build-minnie/Lite.cbp new file mode 100644 index 00000000000..686f7908508 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/Lite.cbp @@ -0,0 +1,4235 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-minnie/googletest/CTestTestfile.cmake b/mindspore/lite/cmake-build-minnie/googletest/CTestTestfile.cmake new file mode 100644 index 00000000000..356aeec8f07 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/CTestTestfile.cmake @@ -0,0 +1,7 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-minnie/googletest +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. +subdirs("googlemock") diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/CTestTestfile.cmake b/mindspore/lite/cmake-build-minnie/googletest/googlemock/CTestTestfile.cmake new file mode 100644 index 00000000000..59a3fc27a51 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/CTestTestfile.cmake @@ -0,0 +1,7 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest/googlemock +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-minnie/googletest/googlemock +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. +subdirs("gtest") diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gmock.cbp b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gmock.cbp new file mode 100644 index 00000000000..5ed2a552da5 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gmock.cbp @@ -0,0 +1,517 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/CTestTestfile.cmake b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/CTestTestfile.cmake new file mode 100644 index 00000000000..caf7474c1db --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/CTestTestfile.cmake @@ -0,0 +1,6 @@ +# CMake generated Testfile for +# Source directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/third_party/googletest/googletest +# Build directory: /mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest +# +# This file includes the relevant testing commands required for +# testing this directory and lists subdirectories to be tested as well. diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfig.cmake b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfig.cmake new file mode 100644 index 00000000000..0ee9ec8f9f5 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfig.cmake @@ -0,0 +1,33 @@ + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was Config.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### +include(CMakeFindDependencyMacro) +if (ON) + set(THREADS_PREFER_PTHREAD_FLAG ON) + find_dependency(Threads) +endif() + +include("${CMAKE_CURRENT_LIST_DIR}/GTestTargets.cmake") +check_required_components("") diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake new file mode 100644 index 00000000000..b12397d6586 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/GTestConfigVersion.cmake @@ -0,0 +1,37 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version. +# The variable CVF_VERSION must be set before calling configure_file(). + +set(PACKAGE_VERSION "1.9.0") + +if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() + + +# if the installed project requested no architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock.pc b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock.pc new file mode 100644 index 00000000000..d4242cfa66f --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gmock +Description: GoogleMock (without main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgmock -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock_main.pc b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock_main.pc new file mode 100644 index 00000000000..2da4fbcc017 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gmock_main.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gmock_main +Description: GoogleMock (with main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgmock_main -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest.pc b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest.pc new file mode 100644 index 00000000000..a9931b8be97 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest.pc @@ -0,0 +1,9 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gtest +Description: GoogleTest (without main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgtest -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest_main.pc b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest_main.pc new file mode 100644 index 00000000000..57948c76ee1 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/generated/gtest_main.pc @@ -0,0 +1,10 @@ +libdir=/usr/local/lib +includedir=/usr/local/include + +Name: gtest_main +Description: GoogleTest (with main() function) +Version: 1.9.0 +URL: https://github.com/google/googletest +Requires: gtest +Libs: -L${libdir} -lgtest_main -pthread +Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -pthread diff --git a/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/gtest.cbp b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/gtest.cbp new file mode 100644 index 00000000000..b79931c4482 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googlemock/gtest/gtest.cbp @@ -0,0 +1,327 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-minnie/googletest/googletest-distribution.cbp b/mindspore/lite/cmake-build-minnie/googletest/googletest-distribution.cbp new file mode 100644 index 00000000000..429f759103f --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/googletest/googletest-distribution.cbp @@ -0,0 +1,517 @@ + + + + + + diff --git a/mindspore/lite/cmake-build-minnie/src/runtime/kernel/arm/opclib/optimize.cbp b/mindspore/lite/cmake-build-minnie/src/runtime/kernel/arm/opclib/optimize.cbp new file mode 100644 index 00000000000..07af7ee7791 --- /dev/null +++ b/mindspore/lite/cmake-build-minnie/src/runtime/kernel/arm/opclib/optimize.cbp @@ -0,0 +1,112 @@ + + + + + + diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h new file mode 100644 index 00000000000..6da548ac3a2 --- /dev/null +++ b/mindspore/lite/include/context.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_ +#define MINDSPORE_LITE_INCLUDE_CONTEXT_H_ + +#include +#include +#include "include/ms_tensor.h" + +namespace mindspore::lite { +class Allocator; +enum CpuBindMode { + MID_CPU = -1, /**< bind mid cpu first */ + HIGHER_CPU = 1, /**< bind higher cpu first */ + NO_BIND = 0 /**< no bind */ +}; + +typedef enum { DT_CPU, DT_GPU, DT_NPU } DeviceType; + +// brief NPUContext defined by MindSpore predict +typedef struct { + int freq{3}; + int fmkType{0}; + int modelType{0}; + int deviceType{0}; + std::string modelName = "default"; +} NPUContext; + +// brief DeviceContext defined by MindSpore predict +typedef struct { + DeviceType type; + // DLContext primary; + NPUContext npuCtx; +} DeviceContext; + +// brief Context defined by MindSpore predict +class MS_API Context { + public: + // brief Constructor of MindSpore predict context using default value for parameters + // + // return Instance of MindSpore predict context. + Context(); + + // brief Constructor of MindSpore predict context using input value for parameters + // + // param[in] threadNum Define the threadNum during the runtime. + // param[in] allocator Define the allocator for malloc. + // param[in] deviceCtx Define device information during the runtime. + Context(int threadNum, std::shared_ptr allocator, DeviceContext deviceCtx); + + // brief Destructor of MindSpore predict context + virtual ~Context(); + + public: + DeviceContext deviceCtx; + int threadNum = 2; + std::shared_ptr allocator; + CpuBindMode cpuBindMode = MID_CPU; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_ + diff --git a/mindspore/lite/include/errorcode.h b/mindspore/lite/include/errorcode.h new file mode 100644 index 00000000000..2cdd4659dea --- /dev/null +++ b/mindspore/lite/include/errorcode.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ +#define MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ + +namespace mindspore { +namespace lite { +using STATUS = int; + +/* Success */ +constexpr int RET_OK = 0; /**< No error occurs. */ + +/* Common error code, range: [-1, -100]*/ +constexpr int RET_ERROR = -1; /**< Common error code. */ +constexpr int RET_NULL_PTR = -2; /**< NULL pointer returned.*/ +constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/ +constexpr int RET_NO_CHANGE = -4; /**< No change. */ +constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */ +constexpr int RET_MEMORY_FAILED = -6; /**< Create memory failed. */ + +/* Executor error code, range: [-101,-200] */ +constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to checking range. */ +constexpr int RET_INPUT_TENSOR_ERROR = -102; /**< Failed to checking input tensor. */ +constexpr int RET_REENTRANT_ERROR = -103; /**< Exist executor running. */ + +/* Graph error code, range: [-201,-300] */ +constexpr int RET_GRAPH_FILE_ERR = -201; /**< Failed to verify graph file. */ + +/* Node error code, range: [-301,-400] */ +constexpr int RET_NOT_FIND_OP = -301; /**< Failed to find operator. */ +constexpr int RET_INVALID_OP_NAME = -302; /**< Invalid operator name. */ +constexpr int RET_INVALID_OP_ATTR = -303; /**< Invalid operator attr. */ +constexpr int RET_OP_EXECUTE_FAILURE = -304; /**< Failed to execution operator. */ + +/* Tensor error code, range: [-401,-500] */ +constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */ +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_ERRORCODE_H_ + diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h new file mode 100644 index 00000000000..24f72c905cf --- /dev/null +++ b/mindspore/lite/include/lite_session.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H +#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H + +#include +#include +#include +#include "include/ms_tensor.h" +#include "include/model.h" +#include "include/context.h" + +namespace mindspore { +namespace session { +class MS_API LiteSession { + public: + virtual ~LiteSession() = default; + + virtual void BindThread(bool ifBind) = 0; + + static LiteSession *CreateSession(lite::Context *context); + + virtual int CompileGraph(lite::Model *model) = 0; + + virtual std::vector GetInputs() = 0; + + virtual std::vector GetInputsByName(std::string name) = 0; + + virtual int RunGraph() = 0; + + virtual std::vector GetOutputs() = 0; + + virtual std::vector GetOutputsByName(std::string name) = 0; +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H + diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h new file mode 100644 index 00000000000..f53d77d3539 --- /dev/null +++ b/mindspore/lite/include/model.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_INCLUDE_MODEL_H +#define MINDSPORE_LITE_INCLUDE_MODEL_H + +#include +#include +#include +#include "schema/model_generated.h" + +namespace mindspore { +class ModelImpl; +namespace lite { +class Primitive; +class Model { + public: + static std::shared_ptr Import(const char *model_buf, size_t size); + virtual ~Model() = default; + Model() = default; + lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *GetMetaGraph() const; + std::shared_ptr GetModelImpl(); + void FreeMetaGraph(); + + protected: + std::shared_ptr modelImpl = nullptr; +}; +class ModelBuilder { + public: + struct OutEdge { + std::string nodeId; + size_t outEdgeIndex; + }; + ModelBuilder() = default; + virtual ~ModelBuilder() = default; + virtual std::string AddOp(const lite::Primitive &op, const std::vector &inputs) = 0; + virtual Model *Construct(); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H + diff --git a/mindspore/lite/include/ms_tensor.h b/mindspore/lite/include/ms_tensor.h new file mode 100644 index 00000000000..bb99ffcc730 --- /dev/null +++ b/mindspore/lite/include/ms_tensor.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_ +#define MINDSPORE_INCLUDE_MS_TENSOR_H_ + +#include +#include +#include +#include "ir/dtype/type_id.h" + +namespace mindspore { +#define MS_API __attribute__((visibility("default"))) +namespace tensor { +class MS_API MSTensor { + public: + MSTensor() = default; + // brief Create a MSTensor pointer. + // + // param data_type DataTypeId of tensor to be created. + // param shape Shape of tensor to be created. + // return MSTensor pointer. + static MSTensor *CreateTensor(TypeId data_type, const std::vector &shape); + + virtual ~MSTensor() = default; + + virtual TypeId data_type() const = 0; + + virtual TypeId set_data_type(const TypeId data_type) = 0; + + virtual std::vector shape() const = 0; + + virtual size_t set_shape(const std::vector &shape) = 0; + + virtual int DimensionSize(size_t index) const = 0; + // brief Get number of element in MSTensor. + // + // return Number of element in MSTensor. + virtual int ElementsNum() const = 0; + + virtual std::size_t hash() const = 0; + // brief Get byte size of data in MSTensor. + // + // return Byte size of data in MSTensor. + virtual size_t Size() const = 0; + // brief Get pointer of data in MSTensor. + // + // The data pointer can be used to both write or read data in MSTensor. + // + // return A pointer points to data in MSTensor. + virtual void *MutableData() const = 0; +}; +using MultiTensor = std::vector>>; +} // namespace tensor +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_ + diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs new file mode 100644 index 00000000000..060998e2463 --- /dev/null +++ b/mindspore/lite/schema/model.fbs @@ -0,0 +1,208 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +include "ops.fbs"; + +namespace mindspore.schema; + +enum NodeType: int { + ValueNode, // const + Parameter, // var + CNode // op +} + +table QuantParam { + scale: double; + zeroPoint: int; + min: double = 0; + max: double = 0; + narrowRange: bool = true; + numBits: int = 8; + inited: bool = false; +} + +table Tensor { + nodeType: NodeType; + // data type + dataType: int; + // shape + dims: [int]; + format: Format; + refCount: int; + offset: int; + data: [ubyte]; + quantParams: [QuantParam]; +} + +union PrimitiveType { + Concat, + SoftMax, + Activation, + Conv2D, + FusedBatchNorm, + CaffeBatchNorm, + BiasAdd, + Pooling, + DepthwiseConv2D, + DeDepthwiseConv2D, + Resize, + DetectionPostProcess, + FullConnection, + Mean, + DeConv2D, + Scale, + Reshape, + Eltwise, + NetOutput, + Add, + Sub, + MatMul, + StridedSlice, + Power, + Slice, + Stack, + Mul, + RealDiv, + Pad, + Maximum, + Minimum, + CaffePReLU, + LeakyReLU, + ArgMax, + ArgMin, + Exp, + Crop, + Range, + Rsqrt, + ExpandDims, + Tile, + Cast, + Shape, + Nchw2Nhwc, + Nhwc2Nchw, + QuantDTypeCast, + Split, + Permute, + FakeQuantWithMinMaxVars, + Equal, + Less, + Greater, + NotEqual, + LessEqual, + GreaterEqual, + Min, + Floor, + Abs, + Neg, + Cos, + Sin, + Sqrt, + Square, + Constant, + Log, + Tan, + Atan, + Asin, + Clip, + Transpose, + Squeeze, + Unsqueeze, + Upsample, + Dropout, + Broadcast, + BroadcastTo, + Lrn, + Prelu, + ZerosLike, + TopK, + SpaceToDepth, + SpaceToBatch, + SparseToDense, + ReverseSequence, + Rank, + Gather, + GatherNd, + Fill, + Elu, + DepthToSpace, + BatchToSpace, + AddN, + Ceil, + EmbeddingLookup, + EmbeddingLookupSparse, + FloorDiv, + FloorMod, + L2Norm, + LocalResponseNormalization, + MatrixDiag, + Reduce, + Reverse, + Round, + Select, + Scatter, + ScatterND, + Unique, + Unstack, + LogicalAnd, + LogicalOr, + LogicalXor, + LogicalNot, + OnnxInt8Quantize, + OnnxInt8Dequantize, + FakeQuantWithMinMax, + FakeQuantWithMinMaxPerChannel, + BatchNormFold, + MulFold, + AddFold, + SquaredDifference, + Flatten, + TupleGetItem, + Div, + Where, + OneHot +} + +enum QuantType: int { + QUANT_NONE, + AwareTrainning, + WeightQuant, + PostTraining +} + +table Primitive { + value: PrimitiveType; +} + +table CNode { + name: string; + nodeType: NodeType = CNode; + primitive: Primitive; + inputIndex: [uint]; + outputIndex: [uint]; + quantType: QuantType = QUANT_NONE; +} + +table MetaGraph { + name: string; + fmkType: int; // 0:tf,1:caffe + inputIndex: [uint]; + outputIndex: [uint]; + mempoolSize: uint; + nodes: [CNode]; + allTensors: [Tensor]; // weight + input + output +} + +root_type MetaGraph; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs new file mode 100644 index 00000000000..45b7558625d --- /dev/null +++ b/mindspore/lite/schema/ops.fbs @@ -0,0 +1,719 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace mindspore.schema; + +enum ResizeMethod: byte { + UNKNOW = -1, + BILINEAR = 0, + NEAREST_NEIGHBOR = 1 +} + +enum Format : int { + NCHW = 0, + NHWC, + NHWC4, + HWKC, + HWCK, + KCHW, + CKHW, + KHWC, + CHWK, + NC4HW4 = 100, + NUM_OF_FORMAT +} + +enum ActivationType : byte { + NO_ACTIVATION = 0, + RELU = 1, + SIGMOID = 2, + RELU6 = 3, + ELU = 4, + LEAKY_RELU = 5, + ABS = 6, + RELU1 = 7, + SOFTSIGN = 8, + SOFTPLUS = 9, + TANH = 10, + SELU = 11, + HSWISH = 12, + HSIGMOID = 13, + THRESHOLDRELU = 14, + LINEAR = 15, + UNKNOW = 16 +} + +enum ReduceType : byte { + REDUCE_MAX = 0, + REDUCE_MEAN = 1, + REDUCE_ALL = 2, + REDUCE_ANY = 3, + REDUCE_LOG_SUM_EXP = 4, + REDUCE_PROD = 5, + REDUCE_SUM = 6, + UNKNOW = 7 +} + +enum PoolMode : byte { + MAX_POOLING = 0, + MEAN_POOLING = 1, +} + +enum EltwiseMode : byte { + PROD = 0, + SUM = 1, + MAXIMUM = 2, + UNKNOW = 3 +} + +enum PadMode : byte { + NOTSET = 0, + SAME = 1, + VALID = 2, + CAFFE = 4 +} + +enum RoundMode : byte { + FLOOR = 0, + CEIL = 1 +} + +enum PaddingMode : byte { + CONSTANT = 0, + REFLECT = 1, + SYMMETRIC = 2, + MODE_RESERVED = 3 +} + +table Pad { + paddingmode: PaddingMode; + paddings: [int]; +} + +table Maximum { +} + +table Minimum { +} + +table Flatten { +} + +table Concat { + axis: int; + n: int; +} + +table SoftMax { + axis: int; +} + +table Activation { + type: ActivationType = 0; +} + +table Conv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table FusedBatchNorm { + epsilon: float = 0.00001; // eg. epsilon=0.001 + momentum: float = 0.9; + spatial: int = 1; +} + +table CaffeBatchNorm { + epsilon: float; // eg. epsilon=0.001 +} + +table Shape { +} + +table Nchw2Nhwc { + +} + +table Nhwc2Nchw { + +} + +table FakeQuantWithMinMaxVars { + narrowRange: bool; + numBits: int; +} + +table BiasAdd { + axis: [int]; +} + +table Pooling { + format: Format = 0; + poolingMode: PoolMode; + global: bool = false; + windowW: int; + windowH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + roundMode: RoundMode; +} + +table DepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table DeDepthwiseConv2D { + format: Format = 0; + channelIn: int; + channelMultiplier: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + + +table Resize { + format: Format = 0; + method: ResizeMethod; + newHeight: long; + newWidth: long; + alignCorners: bool = false; + preserveAspectRatio: bool = false; +} + +table DetectionPostProcess { + format: Format = 0; + inputSize: int; + hScale: float; + wScale: float; + xScale: float; + yScale: float; + NmsIouThreshold: float; + NmsScoreThreshold: float; + MaxDetections: long; + DetectionsPreClass: long; + MaxClassesPreDetection: long; + NumClasses: long; + UseRegularNms: bool; +} + +table FullConnection { + hasBias: bool; + axis: int; +} + +// Mean(input_tensor, axis, keep_dims) +table Mean { + axis: [int]; + keepDims: bool = false; +} + +table DeConv2D { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table Scale { + format: Format = 0; +} + +table Eltwise { + mode: EltwiseMode; +} + +table Add { +} + +table Sub { +} + +table Mul { +} + +table Div { +} + +table RealDiv { +} + +table Rsqrt { +} + +table Equal { +} + +table Less { +} + +table Greater { +} + +table NotEqual { +} + +table LessEqual { +} + +table GreaterEqual { +} + +table Min { +} + +table Slice { + format: Format = 0; + begin: [int]; + size: [int]; +} + +table Floor { +} + +table Abs { +} + +table Neg { +} + +table Exp { +} + +table Cos { +} + +table Sin { +} + +table Sqrt { +} + +table Square { +} + +table Ceil { +} + +table Log { +} + +table Tan { +} + +table Atan { +} + +table Asin { +} + +table Reshape { + format: Format = 0; + shape: [long]; +} + +table Power { + power: float; + scale: float; + shift: float; +} + +table ArgMax { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table ArgMin { + axis: int; + outMaxValue: bool; + topK: int = 1; + keepDims: bool; + axisType: int; +} + +table NetOutput { +} + +table MatMul { + transposeA : bool = false; + transposeB : bool = false; +} + +table CaffePReLU { + channelShared : bool = false; +} + +table LeakyReLU { + negativeSlope: float; +} + +table StridedSlice { + beginMask: int; + endMask: int; + ellipsisMask: int; + newAxisMask: int; + shrinkAxisMask: int; + begin: [int]; + end: [int]; + stride: [int]; + isScale: [int]; +} + +table Stack { + axis: int; + n: int; + isScale: [int]; +} + +table Range { + dType: int; + start: int; + limit: int; + delta: int; +} + +table ExpandDims { + dim: int; +} + +table Tile { + multiples: [int]; +} + +table Cast { + srcT: int; + dstT: int; +} + +table QuantDTypeCast { + srcT: int; + dstT: int; +} + +table Split { + numberSplit: int; + sizeSplits: [int]; + splitDim: int; +} + +table Crop { + axis : long; + offsets : [long]; +} + +table Permute { + order: [long]; +} + +table Clip { + max: float; + min: float; +} + +table Constant { +} + + +table Elu { + alpha: float = 1.0; +} + +table Broadcast { +} + +table BroadcastTo { + dst_shape: [int]; +} + +table Lrn { + alpha: float = 0.0001; + beta: float = 0.75; + bias: float = 1.0; + size: int; +} + +enum ReduceMode : byte { + ReduceMean = 0, + ReduceMax = 1, + ReduceMin = 2, + ReduceProd = 3, + ReduceSum = 4, + ReduceSumSquare = 5 +} + +table Reduce { + axes: [int]; + keepDims: int; + mode: ReduceMode; +} + +table Prelu { + slope: [float]; +} + +table Transpose { + perm: [int]; + conjugate: bool = false; +} + +table Squeeze { + axis: [int]; +} + +table Unsqueeze { + axis: [int]; +} + +table Upsample { + mode: string; + scales: [float]; +} + +table Dropout { + ratio : float = 0.5; +} + +table LocalResponseNormalization { + depth_radius: int; + bias: float; + alpha: float; + beta: float; +} + +table ZerosLike { +} + +table TopK { + k : int; + sorted : bool = true; +} + +table SpaceToDepth { + blockSize : int; + format: Format = 0; +} + +table SpaceToBatch { + blockShape : [int]; + paddings : [int]; +} + +table SparseToDense { + validateIndices: bool; +} + +table ReverseSequence { + seqAxis: int; + batchAxis: int; +} + +table Rank { +} + + +table Gather { + axis: int; + batchDims: int; +} + +table GatherNd { + batchDims: int; +} + +table Fill { + dims: [int]; +} + +table DepthToSpace { + blockSize: int; + format: Format = 0; +} + + +table BatchToSpace { + blockShape: [int]; + crops: [int]; +} + +table AddN { + N: int; +} + + +table EmbeddingLookup { + ids: [int]; + maxNorm: float; +} + +table EmbeddingLookupSparse { + spIds: [int]; + spWeights: [float]; + //combiner: Combiner=0; + maxNortm: float; +} + +table FloorDiv { +} + +table FloorMod { +} + +table L2Norm { + axis: [int]; + epsilon: float; +} + +table LogicalAnd { +} + +table LogicalOr { +} + +table LogicalXor { +} + +table LogicalNot { +} + +table MatrixDiag { + k: int; + numRows: int; + numCols: int; + paddingValue: float; +} + +table Select { +} + +table TfReduce { + type: ReduceType = 7; +} + +table Reverse { + axis: [int]; +} + +table Round { +} + +table Scatter { +} + +table ScatterND { +} + +table Unique { +} + +table Unstack { + num: int; + axis: int; +} + +table OnnxInt8Quantize { +} + +table OnnxInt8Dequantize { +} + +table FakeQuantWithMinMax { +} + +table FakeQuantWithMinMaxPerChannel { +} + +table BatchNormFold { +} + +table MulFold { +} + +table AddFold { +} + +table SquaredDifference { +} + +table TupleGetItem { +} + +table Where{ +} + +table OneHot { + axis: int; +} diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt new file mode 100644 index 00000000000..bec9eaedcc1 --- /dev/null +++ b/mindspore/lite/src/CMakeLists.txt @@ -0,0 +1,83 @@ +set(LITE_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/workspace_pool.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model.cc + ${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc + ) + +if (SUPPORT_GPU) + list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc) + list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc) +endif() + +if (SUPPORT_TRAIN) + set(ANF_SRC +# ${CCSRC_DIR}/common/trans.cc +# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc +# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc +# ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc +# ${CCSRC_DIR}/session/lite/session_basic_extends.cc +# ${CCSRC_DIR}/session/anf_runtime_algorithm.cc +# ${CCSRC_DIR}/session/session_basic.cc +# ${CCSRC_DIR}/session/kernel_graph.cc +# ${CCSRC_DIR}/session/session_factory.cc +# ${CCSRC_DIR}/device/kernel_info.cc +# ${CCSRC_DIR}/device/kernel_runtime.cc +# ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc + ) + set(PASS_SRC) + set(LITE_SRC + ${LITE_SRC} + ${ANF_SRC} + ${PASS_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc + ) +else () + set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc + ) +endif () + +if (SUPPORT_GPU) + set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc + ) +endif () + +set(ANF_SRC + ${ANF_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc + ) + +add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC}) +target_link_libraries(mindspore-lite + cpu_kernel_mid_ + ops_mid_ + ${SECUREC_LIBRARY} + mindspore::json + ) + +add_subdirectory(runtime/kernel/arm) +add_subdirectory(ops) + diff --git a/mindspore/lite/src/common/anf_exporter/CMakeLists.txt b/mindspore/lite/src/common/anf_exporter/CMakeLists.txt new file mode 100644 index 00000000000..352f59947a7 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(anf_exporter_mid OBJECT + ${ANF_SRC_LIST} + ) + diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc new file mode 100644 index 00000000000..b5a6799db0b --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -0,0 +1,263 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_exporter.h" +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "src/param_value_lite.h" +#include "mindspore/core/ir/primitive.h" +#include "src/ir/primitive_t_value.h" +#include "base/core_ops.h" + +namespace mindspore::lite { +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { + auto cnodes = funcGraph->GetOrderedCnodes(); + auto metaGraphT = std::make_unique(); + for (const auto &cnode : cnodes) { + auto primitive = GetValueNode(cnode->input(0)); + if (primitive != nullptr && primitive == prim::kPrimReturn) { + // set graph outputs tensors + auto inputNode = cnode->input(1); + if (!inputNode->isa()) { + continue; + } + auto inputCNode = utils::cast(inputNode); + auto inputPrimitive = GetValueNode(inputCNode->input(0)); + if (inputPrimitive == prim::kPrimMakeTuple) { + continue; + } else { + std::string inputName = inputNode->fullname_with_scope(); + auto graphOutput = nodeIdMap[inputName]; + metaGraphT->outputIndex.emplace_back(graphOutput); + } + continue; + } + if (primitive != nullptr && primitive == prim::kPrimMakeTuple) { + for (size_t i = 1; i < cnode->inputs().size(); i++) { + auto graphOutNode = cnode->input(i); + if (!graphOutNode->isa()) { + MS_LOG(ERROR) << "Inputs of MakeTuple should be cNode"; + return nullptr; + } + std::string graphOutNodeName = graphOutNode->fullname_with_scope(); + auto graphOutIndex = nodeIdMap[graphOutNodeName]; + metaGraphT->outputIndex.emplace_back(graphOutIndex); + } + continue; + } + + auto node = std::make_unique(); + node->name = cnode->fullname_with_scope(); + node->nodeType = schema::NodeType_CNode; + // populate primitive + if (primitive != nullptr) { + primitive = GetValueNode(cnode->input(0)); + MS_ASSERT(primitive != nullptr); + std::string opType = primitive->name(); + auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + if (nodeParser == nullptr) { + MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; + return nullptr; + } + std::vector outputs; + nodeParser->Parse(cnode, node.get(), &outputs); + SetOpInputNode(cnode, metaGraphT.get(), node.get()); + SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + metaGraphT->nodes.emplace_back(std::move(node)); + continue; + } + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + return nullptr; + } + + auto *lite_primitive = primitiveT_value->GetPrimitiveT(); + if (lite_primitive == nullptr) { + MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr"; + return nullptr; + } + + node->primitive = std::unique_ptr(primitiveT_value->GetPrimitiveT()); + primitiveT_value->SetPrimitiveT(nullptr); + std::vector outputs; + SetOpInputNode(cnode, metaGraphT.get(), node.get()); + SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + + // add quant param + node->quantType = primitiveT_value->GetQuantType(); + if (node->quantType == schema::QuantType_PostTraining) { + MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; + // activation + auto activate_index = node->inputIndex[0]; + auto tensor_input = metaGraphT->allTensors[activate_index].get(); + auto input_quant_params = primitiveT_value->GetInputQuantParams(); + if (input_quant_params.empty()) { + MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty"; + continue; + } + + std::unique_ptr input_quant_param = + std::make_unique(input_quant_params[0]); + tensor_input->quantParams.emplace_back(std::move(input_quant_param)); + // output + auto output_index = node->outputIndex[0]; + auto tensor_output = metaGraphT->allTensors[output_index].get(); + auto output_quant_params = primitiveT_value->GetOutputQuantParams(); + if (output_quant_params.empty()) { + MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; + continue; + } + + std::unique_ptr output_quant_param = + std::make_unique(output_quant_params[0]); + tensor_output->quantParams.emplace_back(std::move(output_quant_param)); + // // TensorType + // valuePtr = primitive->GetAttr(kInputTensorDataType); + // if (valuePtr != nullptr) { + // MS_LOG(INFO) << "node: " << node->name << " input tensor data type: " << GetValue(valuePtr); + // for (auto input : node->inputIndex) { + // auto tensor = subGraph->allTensors[input].get(); + // tensor->dataType = kNumberTypeUInt8; + // } + // } + } + + metaGraphT->nodes.emplace_back(std::move(node)); + } + // set graph input tensors + for (auto node : graphInputNodes) { + for (auto input : node->inputIndex) { + auto tensor = metaGraphT->allTensors[input].get(); + if (tensor->data.empty()) { + tensor->nodeType = schema::NodeType_ValueNode; + // tensor->refCount = lite::MSCONST_WEIGHT_REFCOUNT; + metaGraphT->inputIndex.emplace_back(input); + } + } + } + return metaGraphT.release(); +} + +void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { + MS_ASSERT(nullptr != meta_graph); + MS_ASSERT(nullptr != fbNode); + if (cnode->inputs().size() <= 1) { + return; + } + std::string cNodeName = cnode->fullname_with_scope(); + bool isGraphInput = true; + for (int i = 1; i < static_cast(cnode->inputs().size()); i++) { + auto inputNode = cnode->input(i); + if (inputNode->isa()) { + isGraphInput = false; + std::string inputName = inputNode->fullname_with_scope(); + if (nodeIdMap.find(inputName) != nodeIdMap.end()) { + fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); + } + } else if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + if (paramNode->name().empty()) { + paramNode->set_name(cNodeName + "_i:" + std::to_string(i - 1)); + } + if (nodeIdMap.find(paramNode->name()) != nodeIdMap.end()) { + fbNode->inputIndex.emplace_back(nodeIdMap[paramNode->name()]); + continue; + } + auto paramTensor = std::make_unique(); + auto abstractBase = paramNode->abstract(); + if (abstractBase == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); + MS_ASSERT(false); + return; + } + if (!utils::isa(abstractBase)) { + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); + MS_ASSERT(false); + return; + } + auto abstractTensor = utils::cast(abstractBase); + auto typePtr = abstractTensor->element()->GetTypeTrack(); + MS_ASSERT(typePtr != nullptr); + paramTensor->dataType = typePtr->type_id(); + if (!utils::isa(abstractTensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); + MS_ASSERT(false); + return; + } + paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); + auto paramValue = std::dynamic_pointer_cast(paramNode->default_param()); + if (paramValue != nullptr) { + paramTensor->nodeType = schema::NodeType_ValueNode; + paramTensor->data.resize(paramValue->tensor_size()); + memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); + } + // for (auto &ite : paramValue->quant_param()) { + // auto quantPar = std::make_unique(); + // quantPar->scale = ite->scale; + // quantPar->zeroPoint = ite->zeroPoint; + // quantPar->min = ite->min; + // quantPar->max = ite->max; + // quantPar->narrowRange = ite->narrowRange; + // quantPar->inited = ite->inited; + // quantPar->numBits = ite->numBits; + // paramTensor->quantParams.emplace_back(std::move(quantPar)); + // } + nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); + fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(paramTensor)); + } + } + + if (isGraphInput) { + graphInputNodes.emplace_back(fbNode); + } +} + +void AnfExporter::SetOpOutputNode(const std::vector &outputTensors, schema::MetaGraphT *graph, + schema::CNodeT *cnode) { + MS_ASSERT(nullptr != graph); + MS_ASSERT(nullptr != cnode); + std::string cnodeName = cnode->name; + if (!outputTensors.empty()) { + int i = 0; + for (auto outputTensor : outputTensors) { + std::string name = cnodeName + "_o:" + std::to_string(i); + nodeIdMap[name] = graph->allTensors.size(); + cnode->outputIndex.emplace_back(graph->allTensors.size()); + graph->allTensors.emplace_back(outputTensor); + i++; + } + return; + } + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + cnode->outputIndex.emplace_back(graph->allTensors.size()); + nodeIdMap[cnodeName] = graph->allTensors.size(); + graph->allTensors.emplace_back(msTensor); +} + +schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { + AnfExporter anfExporter; + return anfExporter.Export(funcGraph); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h new file mode 100644 index 00000000000..48d52fd4394 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.h @@ -0,0 +1,46 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ +#define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ + +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +class AnfExporter { + public: + AnfExporter() = default; + virtual ~AnfExporter() = default; + schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); + void SetOpOutputNode(const std::vector &outputTensors, schema::MetaGraphT *graph, + schema::CNodeT *cnode); + void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); + + private: + std::map nodeIdMap; + std::vector graphInputNodes; +}; + +schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ + diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc new file mode 100644 index 00000000000..ab9e1b0cd6f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + if (p->name() == "ReLU") { + attr->type = schema::ActivationType_RELU; + } else if (p->name() == "Sigmoid") { + attr->type = schema::ActivationType_SIGMOID; + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Activation; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); +AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h new file mode 100644 index 00000000000..daa4add19c0 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_activation_populater.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfActivationPopulater : public AnfNodePopulater { + public: + AnfActivationPopulater() = default; + ~AnfActivationPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc new file mode 100644 index 00000000000..d8013aed14a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->epsilon = GetValue(p->GetAttr("epsilon")); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h new file mode 100644 index 00000000000..1df83a87ac7 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H +#define MINDSPORE_ANF_BATCHNORM_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfBatchnormParser : public AnfNodePopulater { + public: + AnfBatchnormParser() = default; + ~AnfBatchnormParser() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc new file mode 100644 index 00000000000..ad59e89936a --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + attr->axis = {0}; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_BiasAdd; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h new file mode 100644 index 00000000000..6256e20567f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_biasadd_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BIASADD_PARSER_H +#define MINDSPORE_ANF_BIASADD_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfBiasAddPopulater : public AnfNodePopulater { + public: + AnfBiasAddPopulater() = default; + ~AnfBiasAddPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BIASADD_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc new file mode 100644 index 00000000000..c2dcee77ecb --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.cc @@ -0,0 +1,121 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + int group = GetValue(p->GetAttr("group")); + + if (group > 1) { + auto attr = std::make_unique(); + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + node->primitive->value.value = attr.release(); + } else { + auto attr = std::make_unique(); + attr->group = group; + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(p->GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(p->GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(p->GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(p->GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + attr->channelOut = GetValue(p->GetAttr("out_channel")); + + auto pad_mode = GetValue(p->GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Conv2D; + node->primitive->value.value = attr.release(); + } + return 0; +} + +AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h new file mode 100644 index 00000000000..88edda09510 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_conv_populater.h @@ -0,0 +1,32 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_CONV_PARSER_H +#define MINDSPORE_ANF_CONV_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfConvPopulater : public AnfNodePopulater { + public: + AnfConvPopulater() = default; + ~AnfConvPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_CONV_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc new file mode 100644 index 00000000000..8ba27f99a74 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Flatten; + node->primitive->value.value = attr.release(); + return 0; +} + +AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h new file mode 100644 index 00000000000..f2cf48ab02c --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_flatten_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H +#define MINDSPORE_ANF_FLATTEN_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfFlattenPopulater : public AnfNodePopulater { + public: + AnfFlattenPopulater() = default; + ~AnfFlattenPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_FLATTEN_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc new file mode 100644 index 00000000000..3bde1465959 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->transposeA = GetValue(p->GetAttr("transpose_a")); + attr->transposeB = GetValue(p->GetAttr("transpose_b")); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_MatMul; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMatmulParser("Matmul", new AnfMatmulPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h new file mode 100644 index 00000000000..752e8eff314 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_matmul_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_MATMUL_PARSER_H +#define MINDSPORE_ANF_MATMUL_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfMatmulPopulater : public AnfNodePopulater { + public: + AnfMatmulPopulater() = default; + ~AnfMatmulPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_MATMUL_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc new file mode 100644 index 00000000000..4f5c3beec82 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Mul; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h new file mode 100644 index 00000000000..87f526cf7aa --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_mul_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfMulPopulater : public AnfNodePopulater { + public: + AnfMulPopulater() = default; + ~AnfMulPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc new file mode 100644 index 00000000000..4045e0e0439 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.cc @@ -0,0 +1,19 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" + +namespace mindspore::lite {} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h new file mode 100644 index 00000000000..e68645090fd --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_NODE_PARSER_H +#define MINDSPORE_ANF_NODE_PARSER_H + +#include +#include "ir/anf.h" +#include "schema/inner/model_generated.h" +namespace mindspore::lite { +class AnfNodePopulater { + public: + AnfNodePopulater() = default; + virtual ~AnfNodePopulater() = default; + virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) = 0; +}; + +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_NODE_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc new file mode 100644 index 00000000000..47a6582bcb6 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include +#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" +#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" +namespace mindspore { +namespace lite { +AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { + static AnfNodePopulaterRegistry instance; + instance.SetNodePopulater("BiasAdd", new AnfBiasAddPopulater()); + instance.SetNodePopulater("Conv2D", new AnfConvPopulater()); + instance.SetNodePopulater("MatMul", new AnfMatmulPopulater()); + instance.SetNodePopulater("MaxPool", new AnfPoolPopulater()); + instance.SetNodePopulater("ReLU", new AnfActivationPopulater()); + instance.SetNodePopulater("Flatten", new AnfFlattenPopulater()); + return &instance; +} +AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { + if (parsers.find(name) == parsers.end()) { + return nullptr; + } + return parsers[name]; +} +void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *parser) { + parsers[name] = parser; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h new file mode 100644 index 00000000000..f8271c14c0f --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_node_populater_registry.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H +#define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +#include +namespace mindspore::lite { +class AnfNodePopulaterRegistry { + public: + AnfNodePopulaterRegistry() = default; + virtual ~AnfNodePopulaterRegistry() = default; + static AnfNodePopulaterRegistry *GetInstance(); + AnfNodePopulater *GetNodePopulater(const std::string &name); + void SetNodePopulater(const std::string &name, AnfNodePopulater *parser); + + private: + std::unordered_map parsers; +}; + +class AnfNodePopulaterRegistrar { + public: + AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *parser) { + AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, parser); + } +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_NODE_PARSER_REGISTRY_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc new file mode 100644 index 00000000000..8c70bb46aee --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_pool_populater.h" +#include +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfPoolPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + if (p->instance_name() == "MaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } else if (p->instance_name() == "MeanPool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } + + auto format = GetValue(p->GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + + auto pad_mode = GetValue(p->GetAttr("padding")); + if (pad_mode == "VALID") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "SAME") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + auto kernel_size = GetValue>(p->GetAttr("ksize")); + attr->windowH = kernel_size[2]; + attr->windowW = kernel_size[3]; + + auto stride = GetValue>(p->GetAttr("strides")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Pooling; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfMaxPoolParser("MaxPool", new AnfPoolPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h new file mode 100644 index 00000000000..a677e7baca7 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_pool_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_POOL_PARSER_H +#define MINDSPORE_ANF_POOL_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfPoolPopulater : public AnfNodePopulater { + public: + AnfPoolPopulater() = default; + ~AnfPoolPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_POOL_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc new file mode 100644 index 00000000000..e7f5f71ff49 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_reducemean_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto p = GetCNodePrimitive(cnodePtr); + auto attr = std::make_unique(); + attr->mode = schema::ReduceMode_ReduceMean; + + attr->keepDims = GetValue(p->GetAttr("keep_dims")); + // attr->axes = GetValue>(p->GetAttr("shape")); + + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Reduce; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfReduceMeanParser("ReduceMean", new AnfReduceMeanPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h new file mode 100644 index 00000000000..16ac3b0c7e4 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_reducemean_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfReduceMeanPopulater : public AnfNodePopulater { + public: + AnfReduceMeanPopulater() = default; + ~AnfReduceMeanPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc new file mode 100644 index 00000000000..e220e45b410 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTensorAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_Add; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTensorAddParser("TensorAdd", new AnfTensorAddPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h new file mode 100644 index 00000000000..d8ff59bba77 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tensoradd_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H +#define MINDSPORE_ANF_ACTIVATION_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTensorAddPopulater : public AnfNodePopulater { + public: + AnfTensorAddPopulater() = default; + ~AnfTensorAddPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc new file mode 100644 index 00000000000..9f6092f4ae8 --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h" +#include +#include +#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" + +namespace mindspore::lite { +int mindspore::lite::AnfTupleGetItemPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node, + std::vector *outputs) { + auto attr = std::make_unique(); + node->nodeType = schema::NodeType_CNode; + node->primitive = std::make_unique(); + node->primitive->value.type = schema::PrimitiveType_TupleGetItem; + node->primitive->value.value = attr.release(); + return 0; +} +AnfNodePopulaterRegistrar anfTupleGetItemParser("tuple_getitem", new AnfTupleGetItemPopulater()); +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h new file mode 100644 index 00000000000..3acf2638c3c --- /dev/null +++ b/mindspore/lite/src/common/anf_exporter/anf_populater/anf_tuple_getitem_populater.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H +#define MINDSPORE_ANF_BATCHNORM_PARSER_H +#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" +#include +namespace mindspore::lite { +class AnfTupleGetItemPopulater : public AnfNodePopulater { + public: + AnfTupleGetItemPopulater() = default; + ~AnfTupleGetItemPopulater() override = default; + int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector *outputs) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.cc b/mindspore/lite/src/common/anf_importer/anf_importer.cc new file mode 100644 index 00000000000..eb9f84eca33 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/anf_importer.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "src/common/anf_importer/anf_importer.h" +#include "schema/model_generated.h" +#include "ir/dtype.h" +#include "ir/primitive.h" +#include "src/param_value_lite.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +#if 0 +PrimitivePtr SetConv2DAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Conv2D(); + PrimitivePtr prim; + if (attrs->group() > 1) { + prim = std::make_shared("DepthwiseConv2D"); + prim->set_instance_name("DepthwiseConv2D"); + } else { + prim = std::make_shared("Conv2D"); + prim->set_instance_name("Conv2D"); + } + + prim->set_attr("group", MakeValue(attrs->group())); + prim->set_attr("format", MakeValue(attrs->format())); + prim->set_attr("pad_mode", MakeValue(attrs->padMode())); + std::vector pad_list = {attrs->padUp(), attrs->padDown(), attrs->padLeft(), attrs->padRight()}; + prim->set_attr("pad_list", MakeValue>(pad_list)); + std::vector dilate = {attrs->dilateH(), attrs->dilateW()}; + prim->set_attr("dilation", MakeValue>(dilate)); + std::vector kernel_size = {attrs->kernelH(), attrs->kernelW()}; + prim->set_attr("kernel_size", MakeValue>(kernel_size)); + std::vector stride = {1, 1, attrs->strideH(), attrs->strideW()}; + prim->set_attr("stride", MakeValue>(stride)); + prim->set_attr("out_channel", MakeValue(attrs->channelOut())); + prim->set_attr("group", MakeValue(attrs->group())); + return prim; +} + +PrimitivePtr SetActivationAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Activation(); + PrimitivePtr prim; + if (attrs->type() == schema::ActivationType_RELU) { + prim = std::make_shared("ReLU"); + prim->set_instance_name("ReLU"); + } + return prim; +} + +PrimitivePtr SetPoolingAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_Pooling(); + PrimitivePtr prim; + if (attrs->poolingMode() == schema::PoolMode_MAX_POOLING) { + prim = std::make_shared("MaxPool"); + prim->set_instance_name("MaxPool"); + } else if (attrs->poolingMode() == schema::PoolMode_MEAN_POOLING) { + prim = std::make_shared("MeanPool"); + prim->set_instance_name("MeanPool"); + } + + prim->set_attr("format", MakeValue(attrs->format())); + prim->set_attr("pad_mode", MakeValue(attrs->padMode())); + prim->set_attr("ksize", MakeValue>(std::vector({1, 1, attrs->windowH(), attrs->windowW()}))); + prim->set_attr("strides", MakeValue>(std::vector({1, 1, attrs->strideH(), attrs->strideW()}))); + return prim; +} + +PrimitivePtr SetFlattenAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Flatten"); + prim->set_instance_name("Flatten"); + return prim; +} + +PrimitivePtr SetMatmulAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive()->value_as_MatMul(); + auto prim = std::make_shared("Matmul"); + prim->set_instance_name("Matmul"); + prim->set_attr("transpose_a", MakeValue(attrs->transposeA())); + prim->set_attr("transpose_b", MakeValue(attrs->transposeB())); + return prim; +} + +PrimitivePtr SetMulAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + // auto attrs = nodedef->attr_as_Mul(); + auto prim = std::make_shared("Mul"); + prim->set_instance_name("Mul"); + return prim; +} + +PrimitivePtr SetSigmoidAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Sigmoid"); + prim->set_instance_name("Sigmoid"); + return prim; +} + +PrimitivePtr SetReduceAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("ReduceMean"); + prim->set_instance_name("ReduceMean"); + return prim; +} + +PrimitivePtr SetBatchNormAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto attrs = cNode->primitive_as_BatchNorm(); + auto prim = std::make_shared("BatchNorm"); + prim->set_attr("is_training", MakeValue(attrs->is_training())); + prim->set_instance_name("BatchNorm"); + return prim; +} + +PrimitivePtr SetBiasAddAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("BiasAdd"); + prim->set_instance_name("BiasAdd"); + return prim; +} + +PrimitivePtr SetAddAttr(const schema::CNode *cNode) { + MS_EXCEPTION_IF_NULL(cNode); + auto prim = std::make_shared("Add"); + prim->set_instance_name("Add"); + return prim; +} + +void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { + auto node_def = graph_def->subgraphs()->begin()->nodes()->GetAs(3); + PrimitivePtr prim = ConverterOperatorAttr(node_def); + if (prim->GetAttr("format")) MS_LOG(INFO) << "find format"; + if (prim->GetAttr("group")) MS_LOG(INFO) << "find group"; +} +#endif + +int AnfImporter::Import() { + ConverterConstTensor(); + auto ret = ConverterCNode(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "ConverterCNode failed " << ret; + return ret; + } + AddReturnCNode(); + return RET_OK; +} + +AnfNodePtr AnfImporter::GetNode(int tensor_id) { + auto n = nodes_.find(tensor_id); + if (n == nodes_.end()) { + return nullptr; + } + return n->second; +} + +void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.h b/mindspore/lite/src/common/anf_importer/anf_importer.h new file mode 100644 index 00000000000..3281294f409 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/anf_importer.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ + +#include +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "base/base.h" + +namespace mindspore::lite { +class AnfImporter { + public: + AnfImporter() = default; + + virtual ~AnfImporter() = default; + + virtual int Import(); + + virtual FuncGraphPtr GetResult() = 0; + + protected: + // convert const tensor into parameter and save in nodes_ + virtual void ConverterConstTensor() = 0; + // convert other node into cnode and save in nodes_ + virtual int ConverterCNode() = 0; + + virtual void AddReturnCNode() = 0; + + AnfNodePtr GetNode(int tensor_id); + + void AddNode(int tensor_id, AnfNodePtr node); + + protected: + std::unordered_map nodes_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc new file mode 100644 index 00000000000..6ec0ba8c545 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_importer/import_from_meta_graph.h" +#include +#include +#include +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "include/errorcode.h" + +namespace mindspore::lite { +void AnfImporterFromMetaGraph::ConverterConstTensor() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + for (size_t i = 0; i < meta_graph->allTensors()->size(); i++) { + auto *tensor = meta_graph->allTensors()->GetAs(i); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor->nodeType() != schema::NodeType_ValueNode) { + continue; + } + MS_ASSERT(tensor->dims() != nullptr); + auto parameter = model->add_parameter(); + std::vector shape; + for (size_t j = 0; j < tensor->dims()->size(); ++j) { + shape.push_back(tensor->dims()->data()[j]); + } + auto type_id = static_cast(tensor->dataType()); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + parameter->set_abstract(abstract_tensor); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (tensor->data() != nullptr) { + auto size = tensor->data()->size(); + char *tensor_data = new char[size](); + std::memcpy(tensor_data, tensor->data()->data(), size); + MS_EXCEPTION_IF_NULL(tensor_data); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + parameter->set_default_param(param_value); + AddNode(i, parameter); + } +} + +int AnfImporterFromMetaGraph::ConverterCNode() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + auto cNodes = meta_graph->nodes(); + for (size_t i = 0; i < cNodes->size(); i++) { + auto cNode = cNodes->GetAs(i); + MS_EXCEPTION_IF_NULL(cNode); + auto tensor_id = cNode->outputIndex()->data()[0]; + if (GetNode(tensor_id)) { + continue; + } + + auto prim = std::make_shared(model->GetOp(cNode->name()->str())); + if (prim == nullptr) { + MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; + return RET_ERROR; + } + auto value_node = NewValueNode(prim); + AddNode(tensor_id, value_node); + + std::vector op_inputs = {value_node}; + MS_EXCEPTION_IF_NULL(cNode->inputIndex()); + for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { + auto node = GetNode(*(cNode->inputIndex()->GetAs(j))); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + // todo: CheckInputNodeType, the first node should be op; + op_inputs.push_back(node); + } + auto cnode = model->NewCNode(op_inputs); + auto node_name = std::string(cNode->name()->c_str()); + cnode->set_fullname_with_scope(node_name); + AddNode(tensor_id, cnode); + } + return RET_OK; +} + +void AnfImporterFromMetaGraph::AddReturnCNode() { + MS_EXCEPTION_IF_NULL(model); + auto *meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + std::vector op_inputs; + auto value_node = NewValueNode(prim::kPrimReturn); + op_inputs.push_back(value_node); + auto tensor_id = meta_graph->outputIndex()->data()[0]; + op_inputs.push_back(GetNode(tensor_id)); + auto cnode = model->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + model->set_return(cnode); +} +FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h new file mode 100644 index 00000000000..fd34930f1cd --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ + +#include +#include "src/train/model_impl.h" +#include "schema/model_generated.h" +#include "src/common/anf_importer/anf_importer.h" + +namespace mindspore::lite { +class AnfImporterFromMetaGraph : public AnfImporter { + public: + explicit AnfImporterFromMetaGraph(std::shared_ptr model) : model(model) {} + + ~AnfImporterFromMetaGraph() override = default; + + FuncGraphPtr GetResult() override; + + private: + void ConverterConstTensor() override; + + int ConverterCNode() override; + + void AddReturnCNode() override; + + private: + std::shared_ptr model = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc new file mode 100644 index 00000000000..c470d6a6e30 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "import_from_meta_graphT.h" +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "src/ir/primitive_value.h" +#include "src/ir/primitive_t_value.h" +#include "include/errorcode.h" +#include "src/ops/ops.h" + +namespace mindspore::lite { +void AnfImporterFromMetaGraphT::ConverterConstTensor() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { + auto &tensor = meta_graph_->allTensors.at(i); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor->nodeType != schema::NodeType_ValueNode) { + continue; + } + MS_ASSERT(tensor->dims() != nullptr); + auto parameter = func_graph_->add_parameter(); + std::vector shape; + for (int &dim : tensor->dims) { + shape.push_back(dim); + } + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + parameter->set_abstract(abstract_tensor); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (!tensor->data.empty()) { + auto size = tensor->data.size(); + char *tensor_data = new char[size]; + std::memcpy(tensor_data, tensor->data.data(), size); + MS_EXCEPTION_IF_NULL(tensor_data); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + parameter->set_default_param(param_value); + AddNode(i, parameter); + } +} + +int AnfImporterFromMetaGraphT::ConverterCNode() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { + auto &cNode = meta_graph_->nodes.at(i); + MS_EXCEPTION_IF_NULL(cNode); + auto tensor_id = cNode->outputIndex.front(); + if (nullptr != GetNode(tensor_id)) { + continue; + } + + auto primTValue = std::make_shared(cNode->primitive.release()); + cNode->primitive = nullptr; + auto value_node = NewValueNode(primTValue); + + std::vector op_inputs = {value_node}; + for (size_t j = 0; j < cNode->inputIndex.size(); j++) { + auto node = GetNode(cNode->inputIndex.at(j)); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + // todo: CheckInputNodeType, the first node should be op; + op_inputs.push_back(node); + } + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope(cNode->name); + AddNode(tensor_id, cnode); + } + return RET_OK; +} + +void AnfImporterFromMetaGraphT::AddReturnCNode() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + std::vector make_tuple_inputs; + auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple); + make_tuple_inputs.emplace_back(make_tuple_value_node); + for (auto tensor_id : meta_graph_->outputIndex) { + make_tuple_inputs.emplace_back(GetNode(tensor_id)); + } + auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); + + std::vector op_inputs; + auto value_node = NewValueNode(prim::kPrimReturn); + op_inputs.emplace_back(value_node); + op_inputs.emplace_back(make_tuple_cnode); + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + func_graph_->set_return(cnode); +} + +FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h new file mode 100644 index 00000000000..5b3799a2560 --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ + +#include + +#include "schema/inner/model_generated.h" +#include "src/common/anf_importer/anf_importer.h" + +namespace mindspore::lite { +class AnfImporterFromMetaGraphT : public AnfImporter { + public: + explicit AnfImporterFromMetaGraphT(schema::MetaGraphT *meta_graph, FuncGraphPtr func_graph) + : meta_graph_(meta_graph), func_graph_(std::move(func_graph)) {} + + ~AnfImporterFromMetaGraphT() override = default; + + FuncGraphPtr GetResult() override; + + private: + void ConverterConstTensor() override; + + int ConverterCNode() override; + + void AddReturnCNode() override; + + private: + schema::MetaGraphT *meta_graph_; + FuncGraphPtr func_graph_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ + diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc new file mode 100644 index 00000000000..f4505fbfb1e --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc @@ -0,0 +1,717 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/anf_importer/import_from_protobuf.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "src/param_value_lite.h" +#include "src/ir/tensor.h" +#include "frontend/operator/ops.h" +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" + +using string = std::string; +using int32 = int32_t; +using int64 = int64_t; +using uint64 = uint64_t; + +namespace mindspore::lite { + +static constexpr char kConstantValueNode[] = "Constant"; +static constexpr char kCNodeShapeAttr[] = "shape"; +static constexpr char kCNodeShape1Attr[] = "shape1"; +static constexpr char kCNodeShape2Attr[] = "shape2"; + +enum ParseForm : int { + FORM_PARSE_TYPE = 0, + FORM_PARSE_SCALAR = 1, + FORM_PARSE_TENSOR = 2, +}; + +static std::map kParseTypeSwitchMap{ + {"type", FORM_PARSE_TYPE}, + {"scalar", FORM_PARSE_SCALAR}, + {"tensor", FORM_PARSE_TENSOR}}; + +static std::unordered_map kDefaultValueSwitchMap{ + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, + {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, + {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, + {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, + {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, + {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, + {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, +}; + +std::shared_ptr ParserScalarAttrValue(const std::string &attr_name, + const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("scalar:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +std::shared_ptr +ParserAttrShape(const std::string &attr_name, const std::unordered_map &kv) { + std::string str = attr_name; + auto replace = [&](const string &orgStr, const string &newStr) { + std::string::size_type pos(0); + while ((pos = str.find(orgStr)) != std::string::npos) { + str.replace(pos, orgStr.length(), newStr); + } + return str; + }; + // remove "scalar:" + str = replace("shape:", ""); + // remove "Tuple" + str = replace("Tuple", ""); + // remove "List" + str = replace("List", ""); + std::stack rules; + std::stack value; + int num = 0, count = 0; + for (size_t i = 0; i < str.length(); i++) { + if (str[i] == '[') { + rules.push("["); + } else if (str[i] == ']') { + // rules + std::vector vec; + while (rules.top() != "[") { + rules.pop(); + vec.push_back(value.top()); + value.pop(); + } + // pop "[" + rules.pop(); + // make tuple for names + std::string res = "dummy"; + // make tuple for values + reverse(vec.begin(), vec.end()); + auto vt = std::make_shared(vec); + if (rules.empty() && value.empty()) { + return vt; + } + rules.push(res); + value.push(vt); + } else if (str[i] == ',') { + continue; + } else { + count++; + if (str[i + 1] == '[' || str[i + 1] == ']' || str[i + 1] == ',') { + auto value_name = str.substr(i - count + 1, count); + value.push(kv.at(value_name)); + rules.push(value_name); + count = 0; + num++; + } + } + } + return {}; +} + +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ +ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ +if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ +} else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ +} \ +return{}; \ +} + +PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) +PARSE_ONNXATTR_IN_SCALAR_FORM(float, float) +PARSE_ONNXATTR_IN_SCALAR_FORM(string, string) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, int32) +PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) +PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) +PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) + +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto) { + MS_EXCEPTION_IF_NULL(node); + if (!value_proto.has_type() || !value_proto.has_name()) { + MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; + return false; + } + node->set_name(value_proto.name()); + const auto &type_proto = value_proto.type(); + if (!type_proto.has_tensor_type()) { + MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! "; + return false; + } + const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type(); + if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! "; + return false; + } + const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); + std::vector shape; + for (int i = 0; i < tensor_shape.dim_size(); ++i) { + shape.push_back(tensor_shape.dim(i).dim_value()); + } + + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; + return false; + } + + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + node->set_abstract(abstract_tensor); + + if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { + tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + MS_EXCEPTION_IF_NULL(tensor_info); + tensor_info->MallocData(); + const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; + std::string initial_data = initialize_proto.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + MS_EXCEPTION_IF_NULL(tensor_data_buf); + memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); + + ParamValueLitePtr param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_tensor_addr(tensor_data_buf); + param_value->set_tensor_size(tensor_info->Size()); + node->set_default_param(param_value); + } + anfnode_build_map_[value_proto.name()] = node; + return true; +} + +bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); + + for (int i = 0; i < importProto.initializer_size(); ++i) { + const onnx::TensorProto &initializer_proto = importProto.initializer(i); + if (!initializer_proto.has_name()) { + MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; + return false; + } + default_para_map_[initializer_proto.name()] = initializer_proto; + } + + MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); + for (int i = 0; i < importProto.input_size(); ++i) { + const onnx::ValueInfoProto &input_proto = importProto.input(i); + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { + MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; + return false; + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; + return false; + } + prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + return true; +} + +ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + switch (attr_tensor_type) { + case onnx::TensorProto_DataType_STRING: { + return ParseAttrInScalar_string_string(attr_tensor); + } + case onnx::TensorProto_DataType_INT32: { + return ParseAttrInScalar_int32_int32(attr_tensor); + } + case onnx::TensorProto_DataType_INT64: { + return ParseAttrInScalar_int64_int64(attr_tensor); + } + case onnx::TensorProto_DataType_UINT64: { + return ParseAttrInScalar_uint64_uint64(attr_tensor); + } + case onnx::TensorProto_DataType_FLOAT: { + return ParseAttrInScalar_float_float(attr_tensor); + } + case onnx::TensorProto_DataType_DOUBLE: { + return ParseAttrInScalar_double_double(attr_tensor); + } + case onnx::TensorProto_DataType_BOOL: { + return ParseAttrInScalar_int32_bool(attr_tensor); + } + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + return {}; + } + return {}; +} + +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; + return false; +} + +bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { + MS_EXCEPTION_IF_NULL(prim); + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainCNodeAttrInTypeForm(prim, attr_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } + } + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + std::unordered_map::iterator iter = kv.begin(); + prim->AddAttr(attr_name, iter->second); + } else { + auto res = ParserScalarAttrValue(ref_attr_name, kv); + prim->AddAttr(attr_name, res); + } + } + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + std::vector shape; + for (int i = 0; i < attr_tensor.dims_size(); ++i) { + shape.push_back(attr_tensor.dims(i)); + } + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor_info->MallocData(); + const std::string &tensor_buf = attr_tensor.raw_data(); + auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); + memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); + auto abstract_tensor = std::make_shared(type_ptr, shape); + new_value_node->set_abstract(abstract_tensor); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { + const int attr_tensor_type = attr_tensor.data_type(); + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; + return false; + } + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); + anfnode_build_map_[value_node_name] = new_value_node; + return true; +} + +bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_proto) { + const std::string &attr_name = attr_proto.name(); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name"; + return false; + } + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + string type; + std::size_t pos(0); + if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("scalar:").length() - 1); + } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("type:").length() - 1); + } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + switch (kParseTypeSwitchMap[type]) { + case FORM_PARSE_TYPE: { + return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); + } + case FORM_PARSE_SCALAR: { + auto res = ObtainCNodeAttrInScalarForm(attr_tensor); + kv.insert(std::pair(attr_tensor.name(), res)); + break; + } + case FORM_PARSE_TENSOR: { + return ObtainValueNodeInTensorForm(value_node_name, attr_tensor); + } + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + return false; + } + } + + ValueNodePtr new_value_node; + if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR) { + if (kv.size() == 1) { + std::unordered_map::iterator iter = kv.begin(); + new_value_node = NewValueNode(iter->second); + new_value_node->set_abstract(iter->second->ToAbstract()); + } else { + auto value_ptr = ParserScalarAttrValue(ref_attr_name, kv); + new_value_node = NewValueNode(value_ptr); + new_value_node->set_abstract(value_ptr->ToAbstract()); + } + anfnode_build_map_[value_node_name] = new_value_node; + } + return true; +} + +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { + const std::string &value_node_name = node_proto.output(0); + const onnx::AttributeProto &attr_proto = node_proto.attribute(0); + if (!attr_proto.has_ref_attr_name()) { + MS_LOG(ERROR) << "parse ValueNode don't have ref_attr_name"; + return false; + } + return GetAttrValueForValueNode(value_node_name, attr_proto); +} + +std::unordered_map +AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { + std::unordered_map kv; + for (int i = 0; i < attr_proto.tensors_size(); i++) { + std::vector shape_vec; + const onnx::TensorProto &attr_tensor = attr_proto.tensors(i); + for (int j = 0; j < attr_tensor.dims_size(); ++j) { + shape_vec.push_back(attr_tensor.dims(j)); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape_vec); + kv.insert(std::pair(attr_tensor.name(), abstract_tensor)); + } + return kv; +} + +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + if (!node_proto.has_op_type()) { + MS_LOG(ERROR) << "Get CNode op_type failed!"; + return nullptr; + } + const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); + const std::string &node_type = node_proto.op_type(); + PrimitivePtr prim = std::make_shared(node_type); + MS_EXCEPTION_IF_NULL(prim); + prim->set_instance_name(node_type); + std::unordered_map kv; + string shape_ref_attr_name; + for (int i = 0; i < node_proto.attribute_size(); ++i) { + const onnx::AttributeProto &attr_proto = node_proto.attribute(i); + if (attr_proto.ref_attr_name().find("shape:") != string::npos) { + shape_ref_attr_name = attr_proto.ref_attr_name(); + kv = GetAbstractForCNode(attr_proto); + continue; + } + if (!GetAttrValueForCNode(prim, attr_proto)) { + MS_LOG(ERROR) << "Get CNode attr failed!"; + return nullptr; + } + } + + std::vector inputs; + inputs.clear(); + inputs.push_back(NewValueNode(prim)); + for (int i = 0; i < node_proto.input_size(); ++i) { + const std::string &input_name = node_proto.input(i); + if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + return nullptr; + } + inputs.push_back(anfnode_build_map_[input_name]); + } + CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + if (0 == kv.size()) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (1 == kv.size()) { + std::unordered_map::iterator iter = kv.begin(); + cnode_ptr->set_abstract(iter->second); + } else { + auto abstract = ParserAttrShape(shape_ref_attr_name, kv); + cnode_ptr->set_abstract(abstract); + } + + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const onnx::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract(std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { + const onnx::ValueInfoProto &output_node = importProto.output(0); + const onnx::TypeProto &output_typeproto = output_node.type(); + int output_type = output_typeproto.tensor_type().elem_type(); + std::vector output_shape; + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { + output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); + } + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); + auto abstract_tensor = std::make_shared(type_ptr, output_shape); + + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(cnode_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + return_node->set_abstract(abstract_tensor); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success!"; + } + return true; +} + +bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; + for (int i = 0; i < importProto.node_size(); ++i) { + const onnx::NodeProto &node_proto = importProto.node(i); + const std::string &node_type = node_proto.op_type(); + if (node_type == kConstantValueNode) { + if (!BuildValueNodeForFuncGraph(node_proto)) { + MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; + return false; + } + continue; + } + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; + return false; + } + } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); + return true; +} + +bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); + MS_EXCEPTION_IF_NULL(debug_info_ptr); + if (importProto.has_name()) { + debug_info_ptr->set_name(importProto.name()); + } else { + MS_LOG(ERROR) << "FuncGraph under converting has not name!"; + } + + if (!ImportParametersForGraph(outputFuncGraph, importProto)) { + return false; + } + return ImportNodesForGraph(outputFuncGraph, importProto); +} + +bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { + if (!model_proto.has_producer_name()) { + MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; + return false; + } + producer_name_ = model_proto.producer_name(); + + if (!model_proto.has_model_version()) { + MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; + return false; + } + model_version_ = model_proto.model_version(); + + if (!model_proto.has_ir_version()) { + MS_LOG(ERROR) << "Parse model version from pb file failed!"; + return false; + } + ir_version_ = model_proto.ir_version(); + return true; +} + + +int AnfImporterFromProtobuf::Import() { + FuncGraphPtr dstGraph = std::make_shared(); + MS_EXCEPTION_IF_NULL(dstGraph); + if (!ParseModelConfigureInfo(*onnx_model_)) { + MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; + } + const onnx::GraphProto &graphBuild = onnx_model_->graph(); + if (!BuildFuncGraph(dstGraph, graphBuild)) { + MS_LOG(ERROR) << "Build funcgraph failed!"; + return RET_ERROR; + } + func_graph_ = dstGraph; + MS_LOG(INFO) << "Parse pb to build FuncGraph Success!"; + return RET_OK; +} + + +onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { + std::unique_ptr onnx_file(new(std::nothrow) char[PATH_MAX]{0}); + if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { + MS_LOG(ERROR) << "open file failed."; + return nullptr; + } + int fd = open(onnx_file.get(), O_RDONLY); + google::protobuf::io::FileInputStream input(fd); + google::protobuf::io::CodedInputStream code_input(&input); + code_input.SetTotalBytesLimit(INT_MAX, 536870912); + auto onnx_model = new onnx::ModelProto; + bool ret = onnx_model->ParseFromCodedStream(&code_input); + if (!ret) { + MS_LOG(ERROR) << "load onnx file failed"; + delete onnx_model; + return nullptr; + } + (void) close(fd); + MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; + return onnx_model; +} + +FuncGraphPtr AnfImporterFromProtobuf::GetResult() { return this->func_graph_; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h new file mode 100644 index 00000000000..b4fbea9eafe --- /dev/null +++ b/mindspore/lite/src/common/anf_importer/import_from_protobuf.h @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ + +#include +#include +#include +#include + +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "src/common/anf_importer/anf_importer.h" +#include "abstract/abstract_value.h" + +namespace mindspore::lite { +class AnfImporterFromProtobuf : public AnfImporter { + public: + explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) + : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} + + ~AnfImporterFromProtobuf() override = default; + + static onnx::ModelProto *ReadOnnxFromBinary(const std::string &model_path); + + FuncGraphPtr GetResult() override; + + int Import() override; + + private: + void ConverterConstTensor() override {}; + int ConverterCNode() override {}; + void AddReturnCNode() override {}; + bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); + bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto); + bool BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr); + bool GetAttrValueForCNode(const PrimitivePtr &prim, + const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, + const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + ValuePtr ObtainCNodeAttrInScalarForm(const onnx::TensorProto &attr_tensor); + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, + const std::string &attr_name, + const onnx::TensorProto &attr_tensor); + bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); + bool ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const std::string &value_node_name, + const onnx::AttributeProto &attr_tensor); + bool ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor); + std::unordered_map + GetAbstractForCNode(const onnx::AttributeProto &attr_proto); + + private: + std::string producer_name_; + int model_version_{}; + int ir_version_{}; + std::unordered_map anfnode_build_map_; + std::map default_para_map_; + onnx::ModelProto *onnx_model_; + FuncGraphPtr func_graph_; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_PROTOBUF_H_ + diff --git a/mindspore/lite/src/common/common.h b/mindspore/lite/src/common/common.h new file mode 100755 index 00000000000..ed12c49686d --- /dev/null +++ b/mindspore/lite/src/common/common.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_COMMON_H_ +#define MINDSPORE_LITE_COMMON_COMMON_H_ + +#include +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +enum NCHW_SHAPE { NCHW_N = 0, NCHW_C = 1, NCHW_H = 2, NCHW_W = 3 }; +enum NHWC_SHAPE { NHWC_N = 0, NHWC_H = 1, NHWC_W = 2, NHWC_C = 3 }; +enum HWCK_SHAPE { HWCK_H = 0, HWCK_W = 1, HWCK_C = 2, HWCK_K = 3 }; +enum HWKC_SHAPE { HWKC_H = 0, HWKC_W = 1, HWKC_K = 2, HWKC_C = 3 }; +enum KCHW_SHAPE { KCHW_K = 0, KCHW_C = 1, KCHW_H = 2, KCHW_W = 3 }; +enum CKHW_SHAPE { CKHW_C = 0, CKHW_K = 1, CKHW_H = 2, CKHW_W = 3 }; +enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 }; +enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 }; +enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 }; +enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 }; +static constexpr int kNCHWDimNumber = 4; +static constexpr int kNHWCDimNumber = 4; + +static constexpr int TENSOR_MAX_REFCOUNT = 999; + +static const char *DELIM_COLON = ":"; +static const char *DELIM_COMMA = ","; +static const char *DELIM_SLASH = "/"; +static const char *DELIM_DOUBLE_BACKSLASH = "\\"; + +// quantization relative +static const char QUANTIZED_UINT8[] = "QUANTIZED_UINT8"; +static const char QUANTIZED_INT8[] = "QUANTIZED_INT8"; +static const char QUANTIZED_INT16[] = "QUANTIZED_INT16"; +static const char QUANTIZED_UINT16[] = "QUANTIZED_UINT16"; +static const char QUANTIZED_FLOAT16[] = "FLOAT16"; +static const char QUANTIZED_FLOAT32[] = "FLOAT32"; +static const char QUANTIZATION_TYPE_DYNAMIC[] = "DYNAMIC"; +static const char QUANTIZATION_TYPE_STATIC[] = "STATIC"; +static const char CALIB_NORM[] = "NORM"; + +// dims +static const int32_t DIM_DEFAULT_SIZE = 4; + +static const schema::Format DEFAULT_FORMAT = schema::Format_NCHW; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_COMMON_H_ + diff --git a/mindspore/lite/src/common/file_utils.cc b/mindspore/lite/src/common/file_utils.cc new file mode 100644 index 00000000000..01b905db42b --- /dev/null +++ b/mindspore/lite/src/common/file_utils.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "src/common/file_utils.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +#define MAX_FILENAME_LEN 1024 +char *ReadFile(const char *file, size_t *size) { + if (file == nullptr) { + MS_LOG(ERROR) << "file is nullptr"; + return nullptr; + } + MS_ASSERT(size != nullptr); + std::string realPath = RealPath(file); + std::ifstream ifs(realPath); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << realPath << " is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << realPath << " open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + std::unique_ptr buf(new (std::nothrow) char[*size]); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), *size); + ifs.close(); + + return buf.release(); +} + +std::string RealPath(const char *path) { + if (path == nullptr) { + MS_LOG(ERROR) << "path is nullptr"; + return ""; + } + if ((strlen(path)) >= PATH_MAX) { + MS_LOG(ERROR) << "path is too long"; + return ""; + } + std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); + if (resolvedPath == nullptr) { + MS_LOG(ERROR) << "new resolvedPath failed"; + return ""; + } + std::string realPath = realpath(path, resolvedPath.get()); + if (realPath.empty()) { + MS_LOG(ERROR) << "Proto file path is not valid"; + return ""; + } + std::string res = resolvedPath.get(); + + return res; +} + +int WriteToBin(const std::string &file_path, void *data, size_t size) { + std::ofstream out_file; + + out_file.open(file_path.c_str(), std::ios::binary); + if (!out_file.good()) { + return -1; + } + + if (!out_file.is_open()) { + out_file.close(); + return -1; + } + out_file.write(reinterpret_cast(data), size); + return 0; +} + +int CompareOutputData(float *output_data, float *correct_data, int data_size) { + float error = 0; + for (size_t i = 0; i < data_size; i++) { + float abs = fabs(output_data[i] - correct_data[i]); + if (abs > 0.00001) { + error += abs; + } + } + error /= data_size; + if (error > 0.0001) { + printf("has accuracy error!\n"); + printf("%f\n", error); + return 1; + } + return 0; +} + +void CompareOutput(float *output_data, std::string file_path) { + size_t output_size; + auto ground_truth = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); + size_t output_num = output_size / sizeof(float); + printf("output num : %zu\n", output_num); + CompareOutputData(output_data, ground_truth, output_num); +} + +// std::string GetAndroidPackageName() { +// static std::string packageName; +// +// if (!packageName.empty()) { +// return packageName; +// } +// +// char cmdline[MAX_FILENAME_LEN] = {0}; +// int fd = open("/proc/self/cmdline", O_RDONLY); +// +// if (fd >= 0) { +// char ch; +// int i = 0; +// while (read(fd, &ch, sizeof(ch)) > 0 && !isspace(ch)) { +// if (':' == ch) { +// break; +// } +// +// if (('/' == ch) || ('\\' == ch)) { +// (void)memset(cmdline, 0, sizeof(cmdline)); +// i = 0; +// } else { +// cmdline[i] = ch; +// i++; +// } +// } +// close(fd); +// } +// packageName = std::string(cmdline); +// return packageName; +//} + +// std::string GetAndroidPackagePath() { +// std::string packageName = GetAndroidPackageName(); +// if (packageName.empty()) { +// return "./"; +// } +// return "/data/data/" + packageName + '/'; +//} + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/file_utils.h b/mindspore/lite/src/common/file_utils.h new file mode 100644 index 00000000000..ff1ec03e641 --- /dev/null +++ b/mindspore/lite/src/common/file_utils.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_FILE_UTILS_H_ +#define MINDSPORE_LITE_COMMON_FILE_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "src/common/utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +char *ReadFile(const char *file, size_t *size); + +std::string RealPath(const char *path); + +template +void WriteToTxt(const std::string& file_path, void *data, size_t element_size) { + std::ofstream out_file; + out_file.open(file_path, std::ios::out); + auto real_data = reinterpret_cast(data); + for (size_t i = 0; i < element_size; i++) { + out_file << real_data[i] << " "; + } + out_file.close(); +} + +int WriteToBin(const std::string& file_path, void *data, size_t size); + +int CompareOutputData(float *output_data, float *correct_data, int data_size); +void CompareOutput(float *output_data, std::string file_path); + +std::string GetAndroidPackageName(); +std::string GetAndroidPackagePath(); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_FILE_UTILS_H_ + diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc new file mode 100755 index 00000000000..c8472bd7f92 --- /dev/null +++ b/mindspore/lite/src/common/graph_util.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/common/graph_util.h" +#include "src/common/utils.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph) { + MS_ASSERT(nullptr != meta_graph); + std::vector ret; + for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) { + auto input_index = meta_graph->inputIndex()->GetAs(i); + for (size_t j = 0; j < meta_graph->nodes()->size(); j++) { + auto *cNode = meta_graph->nodes()->GetAs(j); + MS_ASSERT(nullptr != cNode); + for (size_t k = 0; k < cNode->inputIndex()->size(); k++) { + if (cNode->inputIndex()->GetAs(k) == input_index) { + ret.emplace_back(j); + break; + } + } + } + } + return std::move(ret); +} + +std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph) { + MS_ASSERT(nullptr != meta_graph); + std::vector ret; + for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) { + auto output_index = meta_graph->outputIndex()->GetAs(i); + for (size_t j = 0; j < meta_graph->nodes()->size(); j++) { + auto *cNode = meta_graph->nodes()->GetAs(j); + MS_ASSERT(nullptr != cNode); + for (size_t k = 0; k < cNode->outputIndex()->size(); k++) { + if (cNode->outputIndex()->GetAs(k) == output_index) { + ret.emplace_back(j); + break; + } + } + } + } + return std::move(ret); +} + +// NODE_ID OpNode::ID() { return id; } +// +// void OpNode::AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); } +// +// void OpNode::AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); } +// +// std::unordered_set OpNode::GetAllInEdges() { return inEdges; } +// +// std::unordered_set OpNode::GetAllOutEdges() { return outEdges; } + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/graph_util.h b/mindspore/lite/src/common/graph_util.h new file mode 100755 index 00000000000..e9a9e994fe8 --- /dev/null +++ b/mindspore/lite/src/common/graph_util.h @@ -0,0 +1,250 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ +#define MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ + +#include +#include +#include +#include +#include +#include "schema/model_generated.h" +#include "utils//log_adapter.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +using NODE_ID = std::string; + +std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph); + +std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph); + +class OpNode { + public: + explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {} + NODE_ID ID() { return id; }; + void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); } + void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); } + std::unordered_set GetAllInEdges() { return inEdges; } + std::unordered_set GetAllOutEdges() { return outEdges; } + + protected: + NODE_ID id; + std::unordered_set inEdges; + std::unordered_set outEdges; +}; + + +template +class OpGraph { + public: + OpGraph() {} + + ~OpGraph(); + + int Build(const schema::MetaGraph *subGraphDef); + NODE_T *GetNode(NODE_ID nodeId); + NODE_T *AddNode(NODE_ID nodeId); + std::unordered_set GetInputNode(); + std::unordered_set GetOutputNode(); + + void AddNodes(std::vector addNodes); + void DeleteNodes(std::vector deleteNodes); + + void AddEdge(NODE_ID nodeId); + int AddEdge(NODE_ID srcId, NODE_ID dstId); + int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector> *opDefs); + std::unordered_map> GetDepends(); + + protected: + std::unordered_map nodes; +}; + +template +int OpGraph::Build(const schema::MetaGraph *subGraphDef) { + if (subGraphDef == nullptr) { + // MS_LOGE("subGraphDef is nullptr"); + return RET_ERROR; + } + + + auto opDefs = subGraphDef->nodes(); + + uint32_t opCount = opDefs->size(); + for (uint32_t i = 0; i < opCount; i++) { + auto opDef = opDefs->GetAs(i); + auto node = AddNode(std::string(opDef->name()->c_str())); + if (node == nullptr) { + // MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str()); + return RET_ERROR; + } + auto ret = AddEdge(opDef, opDefs); + if (ret != RET_OK) { + // MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret); + return RET_ERROR; + } + } + + return RET_OK; +} +template +int OpGraph::AddEdge(const schema::CNode *srcNodeDef, + const flatbuffers::Vector> *nodeDefs) { + MS_ASSERT(srcNodeDef != nullptr); + MS_ASSERT(nodeDefs != nullptr); + NODE_ID srcId = std::string(srcNodeDef->name()->c_str()); + uint32_t opCount = nodeDefs->size(); + // for single op condition + AddNode(srcId); + for (auto index : *(srcNodeDef->outputIndex())) { + for (uint32_t i = 0; i < opCount; i++) { + auto dstNodeDef = nodeDefs->GetAs(i); + bool find = false; + auto inputIndex = dstNodeDef->inputIndex(); + if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) { + find = true; + } + + if (!find) { + continue; + } + NODE_ID dstId = std::string(dstNodeDef->name()->c_str()); + auto ret = AddEdge(srcId, dstId); + if (ret != RET_OK) { + return ret; + } + } + } + + return RET_OK; +} + +template +int OpGraph::AddEdge(NODE_ID srcId, NODE_ID dstId) { + auto srcNode = AddNode(srcId); + if (srcNode == nullptr) { + // MS_LOGE("add srcNode failed"); + return RET_ERROR; + } + auto dstNode = AddNode(dstId); + if (dstNode == nullptr) { + // MS_LOGE("add dstNode failed"); + return RET_ERROR; + } + + srcNode->AddOutEdge(dstNode); + + dstNode->AddInEdge(srcNode); + return RET_OK; +} + +template +NODE_T *OpGraph::GetNode(NODE_ID nodeId) { + auto node = nodes.find(nodeId); + if (node == nodes.end()) { + return nullptr; + } + return node->second; +} + +template +NODE_T *OpGraph::AddNode(NODE_ID nodeId) { + auto node = GetNode(nodeId); + if (node != nullptr) { + return node; + } + node = new (std::nothrow) NODE_T(nodeId); + if (node == nullptr) { + // MS_LOGE("new node failed"); + return nullptr; + } + nodes[nodeId] = node; + return node; +} + +template +void OpGraph::AddNodes(std::vector addNodes) { + for (auto node : addNodes) { + if (node == nullptr) { + return; + } + + nodes[node->ID()] = node; + } +} + +template +void OpGraph::DeleteNodes(std::vector deleteNodes) { + for (auto deletenode : deleteNodes) { + if (deletenode == nullptr) { + continue; + } + auto node = GetNode(deletenode->ID()); + if (node == nullptr) { + continue; + } + nodes.erase(deletenode->ID()); + } +} + +template +std::unordered_set OpGraph::GetInputNode() { + std::unordered_set inputNodes; + for (const auto &iter : nodes) { + auto node = iter.second; + if (node->GetAllInEdges().empty()) { + inputNodes.insert(node); + } + } + return inputNodes; +} + +template +std::unordered_set OpGraph::GetOutputNode() { + std::unordered_set outputNodes; + for (const auto &iter : nodes) { + auto node = iter.second; + if (node->GetAllOutEdges().empty()) { + outputNodes.insert(node); + } + } + return outputNodes; +} + +template +std::unordered_map> OpGraph::GetDepends() { + std::unordered_map> depends; + for (auto nodeIter : nodes) { + depends[nodeIter.second] = nodeIter.second->GetAllInEdges(); + } + return depends; +} + +template +OpGraph::~OpGraph() { + for (auto iter : nodes) { + delete iter.second; + } + nodes.clear(); +} + +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ + diff --git a/mindspore/lite/src/common/graph_utils_extends.cc b/mindspore/lite/src/common/graph_utils_extends.cc new file mode 100644 index 00000000000..828ddb68196 --- /dev/null +++ b/mindspore/lite/src/common/graph_utils_extends.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/graph_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "debug/label.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace { +class DeepFirstSearcher { + public: + explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {} + ~DeepFirstSearcher() = default; + + std::vector Search(const AnfNodePtr &root) { + if (root == nullptr) { + return res_; + } + seen_ = NewSeenGeneration(); + Visit(root); + return res_; + } + + void Visit(const AnfNodePtr &node) { + if (node == nullptr) { + return; + } + if (node->seen_ == seen_) { + return; + } + + node->seen_ = seen_; + + auto incl = include_(node); + if (incl == EXCLUDE) { + return; + } + if (filter_ == nullptr || !filter_(node)) { + res_.push_back(node); + } + if (incl == FOLLOW) { + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + Visit(*iter); + } + return; + } + } + } + + private: + size_t seen_{0}; + IncludeFunc include_; + FilterFunc filter_; + std::vector res_{}; +}; + +class DeepScopedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepScopedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &vnode) { + if (!IsValueNode(vnode)) { + return; + } + + auto graph = GetValueNode(vnode); + AnfNodePtr ret = graph->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } + + void Visit(const ParameterPtr ¶m) { + if (param->func_graph() == nullptr) { + return; + } + + AnfNodePtr ret = param->func_graph()->get_return(); + if (ret != nullptr) { + DeepFirstSearcher::Visit(ret); + } + } +}; + +class DeepUsedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepUsedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &vnode) { return; } +}; + +class DeepLinkedGraphSearcher : public DeepFirstSearcher { + public: + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} + ~DeepLinkedGraphSearcher() = default; + + void Visit(const CNodePtr &cnode) { return; } + + void Visit(const ValueNodePtr &) {} +}; +} // namespace + +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepScopedGraphSearcher(include).Search(root); +} + +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepUsedGraphSearcher(include).Search(root); +} + +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { + return DeepLinkedGraphSearcher(include).Search(root); +} + +} // namespace mindspore + diff --git a/mindspore/lite/src/common/op_utils.h b/mindspore/lite/src/common/op_utils.h new file mode 100755 index 00000000000..68a42171146 --- /dev/null +++ b/mindspore/lite/src/common/op_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_OP_UTILS_H_ +#define MINDSPORE_LITE_COMMON_OP_UTILS_H_ + +#include +#include +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) { return opDef.primitive()->value_type(); } +inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); } +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_OP_UTILS_H_ + diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc new file mode 100644 index 00000000000..bb2e1e9c2b8 --- /dev/null +++ b/mindspore/lite/src/common/utils.cc @@ -0,0 +1,262 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef __ANDROID__ +#include +#endif +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +std::vector StringSplit(std::string str, const std::string& pattern) { + std::vector result; + if (str.empty()) { + return result; + } + std::string::size_type pos; + str += pattern; + auto size = str.size(); + + for (size_t i = 0; i < size; i++) { + pos = str.find(pattern, i); + if (pos < size) { + std::string s = str.substr(i, pos - i); + result.push_back(s); + i = pos + pattern.size() - 1; + } + } + return result; +} + +uint64_t GetTimeUs() { + struct timespec ts = {0, 0}; + if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { + return 0; + } + // USECS_IN_SEC *NSECS_IN_USEC; + uint64_t retval = static_cast((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); + return retval; +} + +static const unsigned int FP32_BIT_SIZE = 32; +static const unsigned int FP32_EXPONENT_BIAS = 127; +static const unsigned int FP32_SIGNIFICAND = 23; + +static const unsigned int FP32_EXPONENT_MAX = 255; + +static const unsigned int FP16_BIT_SIZE = 16; +static const unsigned int FP16_EXPONENT_BIAS = 15; +static const unsigned int FP16_SIGNIFICAND = 10; + +static const int FP16_EXPONENT_MAX = 30; +static const int FP16_EXPONENT_MIN = -10; + +// fp16.c +float ShortToFloat32(int16_t srcValue) { + uint16_t expHalf16 = srcValue & 0x7C00; + int exp1 = static_cast(expHalf16); + uint16_t mantissa16 = srcValue & 0x03FF; + int mantissa1 = static_cast(mantissa16); + int sign = static_cast(srcValue & 0x8000); + sign = sign << FP16_BIT_SIZE; + + // nan or inf + if (expHalf16 == 0x7C00) { + // nan + if (mantissa16 > 0) { + int res = (0x7FC00000 | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + // inf + int res = (0x7F800000 | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + if (expHalf16 != 0) { + exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias + int res = (exp1 | mantissa1); + res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); + res = (res | sign); + int *iRes = &res; + auto fres = static_cast(*iRes); + return fres; + } + + int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); + xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) + << FP32_SIGNIFICAND); // add the bias difference to xmm1 + xmm1 = xmm1 | sign; // Combine with the sign mask + + auto res = static_cast(mantissa1); // Convert mantissa to float + int *ixmm1 = nullptr; + ixmm1 = &xmm1; + res *= static_cast(*ixmm1); + + return res; +} + +// __gnu_f2h_ieee +int16_t Float32ToShort(float srcValue) { + float *psrcValue = nullptr; + psrcValue = &srcValue; + auto srcValueBit = static_cast(*psrcValue); + int sign = srcValueBit >> (FP32_BIT_SIZE - 1); + int mantissa = srcValueBit & 0x007FFFFF; + // exponent + int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; + int16_t res; + if (exp > 0 && exp < FP16_EXPONENT_MAX) { + // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. + res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | + ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } else if (srcValueBit == 0) { + res = 0; + } else { + if (exp <= 0) { + if (exp < FP16_EXPONENT_MIN) { + // value is less than min half float point + res = 0; + } else { + // normalized single, magnitude is less than min normal half float point. + mantissa = (mantissa | 0x00800000) >> (1 - exp); + // round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + } + // combine sign & mantissa (exp is zero to get denormalized number) + res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { + if (mantissa == 0) { + // input float is infinity, return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // input float is NaN, return half NaN + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else { + // exp > 0, normalized single, round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + if ((mantissa & 0x00800000) > 0) { + mantissa = 0; + exp = exp + 1; + } + } + if (exp > FP16_EXPONENT_MAX) { + // exponent overflow - return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // combine sign, exp and mantissa into normalized half + res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | + (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } + } + return res; +} +std::string Remove(const std::string &from, const std::string &subStr, Mode mode) { + std::string result = from; + if (mode == PREFIX) { + if (from.substr(0, subStr.length()) == subStr) { + result = from.substr(subStr.size()); + } + } else if (mode == SUFFIX) { + if (from.rfind(subStr) == from.size() - subStr.size()) { + result = from.substr(0, from.size() - subStr.size()); + } + } else { + size_t index; + while ((index = result.find(subStr)) != std::string::npos) { + result = result.erase(index, subStr.size()); + } + } + + return result; +} + +std::vector StrSplit(const std::string &str, const std::string &pattern) { + std::string::size_type pos; + std::vector result; + std::string tmpStr(str + pattern); + std::string::size_type size = tmpStr.size(); + + for (std::string::size_type i = 0; i < size; i++) { + pos = tmpStr.find(pattern, i); + if (pos < size) { + std::string s = tmpStr.substr(i, pos - i); + result.push_back(s); + i = pos + pattern.size() - 1; + } + } + return result; +} + +std::vector Tokenize(const std::string &src, const std::string &delimiters, + const Option &maxTokenNum) { + if (maxTokenNum.IsSome() && maxTokenNum.Get() == 0) { + return {}; + } + + std::vector tokens; + size_t offset = 0; + + while (true) { + size_t nonDelimiter = src.find_first_not_of(delimiters, offset); + if (nonDelimiter == std::string::npos) { + break; + } + size_t delimiter = src.find_first_of(delimiters, nonDelimiter); + if (delimiter == std::string::npos || (maxTokenNum.IsSome() && tokens.size() == maxTokenNum.Get() - 1)) { + tokens.push_back(src.substr(nonDelimiter)); + break; + } + + tokens.push_back(src.substr(nonDelimiter, delimiter - nonDelimiter)); + offset = delimiter; + } + return tokens; +} + +void ShortToFloat32(const int16_t *srcdata, float *dstdata, size_t elementSize) { + MS_ASSERT(srcdata != nullptr); + MS_ASSERT(dstdata != nullptr); + for (size_t i = 0; i < elementSize; i++) { + dstdata[i] = ShortToFloat32(srcdata[i]); + } +} + +void Float32ToShort(const float *srcdata, int16_t *dstdata, size_t elementSize) { + MS_ASSERT(srcdata != nullptr); + MS_ASSERT(dstdata != nullptr); + for (size_t i = 0; i < elementSize; i++) { + dstdata[i] = Float32ToShort(srcdata[i]); + } +} + +#if defined(__ANDROID__) +uint32_t getHwCap(int hwcap_type) { + uint32_t ret = getauxval(hwcap_type); + return ret; +} +#endif +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h new file mode 100644 index 00000000000..b6d28d8992b --- /dev/null +++ b/mindspore/lite/src/common/utils.h @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_COMMON_UTILS_H_ +#define MINDSPORE_LITE_COMMON_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "tools/common/option.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +const int USEC = 1000000; +const int MSEC = 1000; +std::vector StringSplit(std::string str, const std::string& pattern); + +uint64_t GetTimeUs(void); + +int16_t Float32ToShort(float srcValue); + +float ShortToFloat32(int16_t srcValue); + +void ShortToFloat32(const int16_t *srcdata, float *dstdata, size_t elementSize); + +void Float32ToShort(const float *srcdata, int16_t *dstdata, size_t elementSize); + +#if defined(__arm__) || defined(__aarch64__) +uint32_t getHwCap(int hwcap_type); +#endif + +template +bool IsContain(const std::vector &vec, T element) { + for (auto iter = vec.begin(); iter != vec.end(); iter++) { + if (*iter == element) { + return true; + } + } + return false; +} + +template +bool VectorErase(std::vector *vec, T element) { + bool ret = false; + for (auto iter = vec->begin(); iter != vec->end();) { + if (*iter == element) { + iter = vec->erase(iter); + ret = true; + } else { + iter++; + } + } + return ret; +} + +template +bool VectorReplace(std::vector *vec, T srcElement, T dstElement) { + bool ret = false; + for (auto iter = vec->begin(); iter != vec->end(); iter++) { + if (*iter == srcElement) { + if (!IsContain(*vec, dstElement)) { + *iter = std::move(dstElement); + } else { + vec->erase(iter); + } + ret = true; + break; + } + } + return ret; +} + +const char WHITESPACE[] = "\t\n\v\f\r "; +const char STR_TRUE[] = "true"; +const char STR_FALSE[] = "false"; + +template +Option ToString(T t) { + std::ostringstream out; + out << t; + if (!out.good()) { + return Option(None()); + } + + return Option(out.str()); +} + +template <> +inline Option ToString(bool value) { + return value ? Option(STR_TRUE) : Option(STR_FALSE); +} + +// get the file name from a given path +// for example: "/usr/bin", we will get "bin" +inline std::string GetFileName(const std::string &path) { + char delim = '/'; + + size_t i = path.rfind(delim, path.length()); + if (i != std::string::npos) { + return (path.substr(i + 1, path.length() - i)); + } + + return ""; +} + +// trim the white space character in a string +// see also: macro WHITESPACE defined above +inline void Trim(std::string *input) { + if (input == nullptr) { + return; + } + if (input->empty()) { + return; + } + + input->erase(0, input->find_first_not_of(WHITESPACE)); + input->erase(input->find_last_not_of(WHITESPACE) + 1); +} + +// to judge whether a string is starting with prefix +// for example: "hello world" is starting with "hello" +inline bool StartsWithPrefix(const std::string &source, const std::string &prefix) { + if (source.length() < prefix.length()) { + return false; + } + + return (source.compare(0, prefix.length(), prefix) == 0); +} + +// split string +std::vector StrSplit(const std::string &str, const std::string &pattern); + +// tokenize string +std::vector Tokenize(const std::string &src, const std::string &delimiters, + const Option &maxTokenNum = Option(None())); + +enum Mode { PREFIX, SUFFIX, ANY }; + +// remove redundant charactor +std::string Remove(const std::string &from, const std::string &subStr, Mode mode = ANY); + +template +inline Option GenericParseValue(const std::string &value) { + T ret; + std::istringstream input(value); + input >> ret; + + if (input && input.eof()) { + return Option(ret); + } + + return Option(None()); +} + +template <> +inline Option GenericParseValue(const std::string &value) { + return Option(value); +} + +template <> +inline Option GenericParseValue(const std::string &value) { + if (value == "true") { + return Option(true); + } else if (value == "false") { + return Option(false); + } + + return Option(None()); +} +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_COMMON_UTILS_H_ + diff --git a/mindspore/lite/src/context.cc b/mindspore/lite/src/context.cc new file mode 100644 index 00000000000..4c2b32eb316 --- /dev/null +++ b/mindspore/lite/src/context.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/context.h" +#include "src/runtime/allocator.h" + +namespace mindspore::lite { +Context::Context() { allocator = Allocator::Create(); } + +Context::~Context() = default; + +Context::Context(int threadNum, std::shared_ptr allocator, DeviceContext deviceCtx) { + this->allocator = std::move(allocator); + this->threadNum = threadNum; + this->deviceCtx = std::move(deviceCtx); +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc new file mode 100644 index 00000000000..b95ddd82a85 --- /dev/null +++ b/mindspore/lite/src/executor.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/executor.h" +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "include/errorcode.h" + +namespace mindspore::lite { +int Executor::Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator, + const kernel::KernelCallBack &before, const kernel::KernelCallBack &after) { + MS_ASSERT(nullptr != allocator); + for (auto &inTensor : inputs) { + if (inTensor == nullptr) { + MS_LOG(ERROR) << "Graph input tensor is nullptr"; + return RET_ERROR; + } + if (inTensor->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "Model input tensor should be NHWC"; + return RET_ERROR; + } + } + kernel::LiteKernelUtil::InitTensorRefCount(kernels); + for (auto *kernel : kernels) { + MS_ASSERT(nullptr != kernel); + auto &outputs = kernel->GetOutputs(); + for (auto *output : outputs) { + MS_ASSERT(nullptr != output); + output->MallocData(allocator); + } + kernel::CallBackParam callbackParam; + callbackParam.name_callback_aram = kernel->Name(); + + if (before != nullptr) { + if (!before(kernel->GetInputs(), kernel->GetOutputs(), callbackParam)) { + MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name(); + } + } + auto ret = kernel->Run(); + if (0 != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->Name(); + return ret; + } + + if (after != nullptr) { + if (!after(kernel->GetInputs(), kernel->GetOutputs(), callbackParam)) { + MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name(); + } + } + for (auto input_kernel : kernel->GetInKernels()) { + MS_EXCEPTION_IF_NULL(input_kernel); + ret = input_kernel->DecOutTensorRefCount(allocator); + if (0 != ret) { + MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed"; + } + } + } + return RET_OK; +} + +int Executor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + auto data_type = tensor->data_type(); + switch (data_type) { + case kNumberTypeInt8: + return TransformTensorLayoutUint8(tensor, dst_format, allocator); + case kNumberTypeFloat32: + return TransformTensorLayoutFp32(tensor, dst_format, allocator); + } + return RET_OK; +} + +int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + auto src_format = tensor->GetFormat(); + if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC) { + auto *src_data = tensor->Data(); + auto *dst_data = allocator->Malloc(tensor->Size()); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + PackNC4HW4ToNHWCFp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); + tensor->SetData(dst_data); + tensor->SetFormat(dst_format); + allocator->Free(src_data); + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator); + MS_ASSERT(4 == tensor->shape().size()); + // auto src_format = tensor->GetFormat(); + // todo + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in uint8"; + return RET_ERROR; +} +} // namespace mindspore::lite + + diff --git a/mindspore/lite/src/executor.h b/mindspore/lite/src/executor.h new file mode 100644 index 00000000000..7a30ee34116 --- /dev/null +++ b/mindspore/lite/src/executor.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_EXECUTOR_H_ + +#include +#include "src/runtime/allocator.h" +#include "src/lite_kernel.h" + +namespace mindspore::lite { +class Executor { + public: + Executor() = default; + + int Prepare(std::vector &kernels) { return 0; } + + int Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator = nullptr, + const kernel::KernelCallBack &before = nullptr, const kernel::KernelCallBack &after = nullptr); + + protected: + int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr); + + protected: + Context *context = nullptr; +}; + +} // namespace mindspore::lite +#endif + diff --git a/mindspore/lite/src/gllo/common/node_pass.cc b/mindspore/lite/src/gllo/common/node_pass.cc new file mode 100644 index 00000000000..badd0fb434e --- /dev/null +++ b/mindspore/lite/src/gllo/common/node_pass.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/gllo/common/node_pass.h" + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" + +namespace mindspore { +namespace opt { +bool NodePass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + + std::unordered_set seen_node; + std::deque todo{func_graph->output()}; + bool changes = false; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + AnfNodePtr new_node = Run(func_graph, node); + bool change = (new_node != nullptr); + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + (void)seen_node.erase(node); + } else if (new_node == nullptr) { + new_node = node; + } + if (new_node && IsValueNode(new_node)) { + auto const_func_graph = GetValueNode(new_node); + MS_EXCEPTION_IF_NULL(const_func_graph); + todo.push_back(const_func_graph->output()); + } else if (new_node && new_node->isa()) { + auto cnode = new_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + changes = changes || change; + } + return changes; +} +} // namespace opt +} // namespace mindspore + diff --git a/mindspore/lite/src/gllo/common/node_pass.h b/mindspore/lite/src/gllo/common/node_pass.h new file mode 100644 index 00000000000..039c09bb8c2 --- /dev/null +++ b/mindspore/lite/src/gllo/common/node_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ +#include +#include + +#include "src/gllo/common/pass.h" + +namespace mindspore { +namespace opt { +// @brief ANF Node level optimization base pass +class NodePass : public Pass { + public: + explicit NodePass(const std::string &name) : Pass(name) {} + ~NodePass() override = default; + bool Run(const FuncGraphPtr &func_graph) final; + virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; +}; +using NodePassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_ diff --git a/mindspore/lite/src/gllo/common/optimizer.cc b/mindspore/lite/src/gllo/common/optimizer.cc new file mode 100644 index 00000000000..925e02f847c --- /dev/null +++ b/mindspore/lite/src/gllo/common/optimizer.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/gllo/common/optimizer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "src/gllo/common/pass_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/manager.h" + +namespace mindspore { +namespace opt { +PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) + : NodePass(name), + multigraph_(multigraph), + pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + primitive_vars_(std::make_shared()) {} + +const BaseRef PatternProcessPass::DefinePattern() const { + VarPtr X = std::make_shared(); + return BaseRef({X}); +} + +void PatternProcessPass::Build() { + VarPtr fg = std::make_shared("RootG"); + BaseRef pattern = std::move(DefinePattern()); + pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); +} + +AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (pattern_ == nullptr) { + Build(); + } + + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(primitive_vars_); + EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); + if (equiv != nullptr && !equiv->empty()) { + return Process(func_graph, node, equiv); + } + return nullptr; +} + +bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + VarPtr fg = std::make_shared("RootG"); + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(child_primitive_vars_); + EquivPtr another_equiv = + child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, + *child_primitive_vars_, empty_equiv); + if (another_equiv != nullptr && !another_equiv->empty()) { + return IsShareNodes(equiv, another_equiv); + } + return false; +} + +void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) { + if (pass_manager != nullptr) { + pass_managers_.push_back(pass_manager); + } +} + +FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { + MS_EXCEPTION_IF_NULL(func_graph); + run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; + auto manager = func_graph->manager(); + if (manager == nullptr) { + manager = Manage(func_graph, false); + func_graph->set_manager(manager); + } + + bool changed = true; + while (changed) { + changed = false; + for (size_t i = 0; i < pass_managers_.size(); ++i) { + const PassManagerPtr &pm = pass_managers_[i]; + if (pm != nullptr && pm->Run(func_graph)) { + changed = true; + } + } + if (run_only_once_) { + break; + } + } + + std::vector func_graphs; + func_graphs.push_back(func_graph); + manager->KeepRoots(func_graphs); + (void)TopoSort(func_graph->get_return()); + return func_graph; +} +} // namespace opt +} // namespace mindspore + diff --git a/mindspore/lite/src/gllo/common/optimizer.h b/mindspore/lite/src/gllo/common/optimizer.h new file mode 100644 index 00000000000..c715f511b54 --- /dev/null +++ b/mindspore/lite/src/gllo/common/optimizer.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/graph_utils.h" +#include "src/common/utils.h" + +#include "src/gllo/common/pass_manager.h" +#include "src/gllo/common/pattern_engine.h" +#include "src/gllo/common/utils.h" + +namespace mindspore { +namespace opt { +using PatternListType = std::initializer_list; + +class PatternProcessPass : public NodePass { + public: + explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); + ~PatternProcessPass() override = default; + virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; + virtual const BaseRef DefinePattern() const; + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + + private: + void Build(); + + AnfNodePtr pattern_ = nullptr; + bool multigraph_ = true; + PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; +}; + +class MultipleOutputPatternProcessPass : public PatternProcessPass { + public: + explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) + : PatternProcessPass(name, multigraph), + child_pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + child_primitive_vars_(std::make_shared()) {} + ~MultipleOutputPatternProcessPass() override = default; + virtual BaseRef DefineAnotherPattern() const = 0; + // check two patterns whether share the same nodes or not + virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; + + protected: + bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; + PatternEngine child_pattern_engine_; + PrimitiveVarMapPtr child_primitive_vars_; +}; + +class GraphOptimizer { + public: + explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} + virtual ~GraphOptimizer() = default; + + void AddPassManager(const PassManagerPtr &pass_manager); + FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); + + private: + const std::string name_ = "graph_optimizer"; + std::vector pass_managers_{}; + bool run_only_once_ = true; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_ + diff --git a/mindspore/lite/src/gllo/common/pass.h b/mindspore/lite/src/gllo/common/pass.h new file mode 100644 index 00000000000..3a3b6927449 --- /dev/null +++ b/mindspore/lite/src/gllo/common/pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ +#include +#include + +#include "ir/anf.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +// @brief ANF Graph level optimization base pass +class Pass { + public: + explicit Pass(const std::string &name = "pass") : name_(name) {} + virtual ~Pass() = default; + virtual bool Run(const FuncGraphPtr &func_graph) = 0; + virtual std::string name() const { return name_; } + + private: + const std::string name_; +}; +using PassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_ diff --git a/mindspore/lite/src/gllo/common/pass_manager.cc b/mindspore/lite/src/gllo/common/pass_manager.cc new file mode 100644 index 00000000000..763228e369b --- /dev/null +++ b/mindspore/lite/src/gllo/common/pass_manager.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/gllo/common/pass_manager.h" + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const std::vector &PassManager::Passes() const { return passes_; } + +void PassManager::AddPass(const PassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + size_t num = 0; + for (const auto &pass : passes) { + if (pass != nullptr) { +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time {}; + struct timeval end_time {}; + (void)gettimeofday(&start_time, nullptr); +#endif + if (pass->Run(func_graph)) { + MS_LOG(DEBUG) << "Run pass and find change"; + changed = true; + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; +#endif + num++; + } + } + return changed; +} + +bool PassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/src/gllo/common/pass_manager.h b/mindspore/lite/src/gllo/common/pass_manager.h new file mode 100644 index 00000000000..d9cbd3a5671 --- /dev/null +++ b/mindspore/lite/src/gllo/common/pass_manager.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "src/gllo/common/pass.h" +#include "src/gllo/common/node_pass.h" + +namespace mindspore { +namespace opt { +// @brief For optimization passes management +class PassManager { + public: + explicit PassManager(const std::string &name = "pm", bool run_only_once = true) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassManager() = default; + // Get all the passes added by AddPass + const std::vector &Passes() const; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PassPtr &pass); + // Run passes added in pass manager on the input graph + // @param [inout] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [inout] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + private: + const std::string name_; + std::vector passes_; + bool run_only_once_; +}; +using PassManagerPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/lite/src/gllo/common/pattern_engine.cc b/mindspore/lite/src/gllo/common/pattern_engine.cc new file mode 100644 index 00000000000..b71e31f4c1b --- /dev/null +++ b/mindspore/lite/src/gllo/common/pattern_engine.cc @@ -0,0 +1,365 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/gllo/common/pattern_engine.h" + +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "mindspore/core/ir/primitive.h" +#include "debug/info.h" +#include "ir/anf.h" +#include "utils/convert_utils_base.h" +#include "utils/overload.h" + + +namespace mindspore { +static int GetNextTag() { + static int kID = 0; + return kID++; +} + +void Var::EnsureTag() { + if (tag_.length() == 0) { + std::ostringstream buffer; + buffer << "_" << GetNextTag(); + tag_ = buffer.str(); + } +} + +bool operator==(const VarPtr &lhs, const VarPtr &rhs) { + if (lhs->isa() && rhs->isa()) { + CondVarPtr v1 = dyn_cast(lhs); + CondVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + + if (lhs->isa() && rhs->isa()) { + SVarPtr v1 = dyn_cast(lhs); + SVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + return (*lhs == *rhs); +} + +std::string SeqVar::ToString() const { + std::ostringstream buffer; + buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; + return buffer.str(); +} + +std::ostream &operator<<(std::ostream &os, const VarPtr &var) { + if (var == nullptr) { + os << ""; + } else { + os << var->ToString(); + } + return os; +} + +template <> + +std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { + os << "[Equiv]" + << "\n"; + for (auto &equiv_item : equiv) { + auto k = equiv_item.first; + os << k << ":"; + BaseRef x = equiv_item.second; + if (utils::isa(x)) { + auto node = utils::cast(x); + os << "TypeString[" << node->type_name() << "]"; + if (IsValueNode(node)) { + os << "IsValueNodeGraph "; + } + os << "type " << node->type_name(); + if (node->isa()) { + os << " value " << GetValueNode(node); + } + os << " addr: " << node; + } else if (utils::isa(x)) { + os << "Named " << x.ToString().c_str(); + } else if (utils::isa(x)) { + os << "TypeString[Var]"; + os << utils::cast(x); + } else if (utils::isa(x)) { + os << "TypeString[Graph]"; + } + os << "\n"; + } + return os; +} + + +static BaseRef GetVar(const BaseRef &x) { + // MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); + if (utils::isa(x)) { + auto node = utils::cast(x); + // MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; + if (node->isa()) { + // MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); + return node->cast()->var_; + } +// if (node->isa()) { +// MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); +// } else { +// MS_LOG(DEBUG) << "type " + node->type_name(); +// } +// } else if (utils::isa(x)) { +// MS_LOG(DEBUG) << "Named " + x.ToString(); +// } else if (utils::isa(x)) { +// MS_LOG(DEBUG) << "VectorRef"; +// } else if (utils::isa(x)) { +// MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); + } +// MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); + return x; +} + +EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { + MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); + MS_EXCEPTION_IF_NULL(equiv); + if (utils::isa(pattern)) { + VarPtr var = utils::cast(pattern); + if (var->matches(expr)) { + (*equiv)[var] = expr; + MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); + return equiv; + } + } + + return nullptr; +} + +bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + if (utils::isa(pattern_ref)) { + *values_pattern = pattern_ref; + *values_expr = expr_ref; + return true; + } + return false; +} + +bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + // visitor to visite the list + auto appender_pattern = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(GetVar(u)); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_pattern(*values_pattern)); + // MS_LOG(DEBUG) << "visit pattern_ref"; + bool success = visitor_->Visit(pattern_ref, nullptr); + if (!success) { + return false; + } + + auto appender_expr = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(u); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_expr(*values_expr)); + // MS_LOG(DEBUG) << "visit expr_ref"; + return visitor_->Visit(expr_ref, nullptr); +} + +static int GetSVarStartIndex(const VectorRef &values) { + int index = -1; + int count = 0; + for (auto &value : values) { + if (utils::isa(value) && utils::cast(value)->isa()) { + if (index != -1) { + // MS_LOG(DEBUG) << "Multiple SVars in sequence"; + return kInvalidVarIndex; + } + index = count; + } + count++; + } + return index; +} + +void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) { + if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || + !utils::isa(expr_ref)) { + return; + } + auto real_node = utils::cast(expr_ref); + MS_EXCEPTION_IF_NULL(real_node); + if (!real_node->isa()) { + return; + } + auto prim_node = utils::cast(values_pattern[0]); + MS_EXCEPTION_IF_NULL(prim_node); + if (!IsValueNode(prim_node)) { + return; + } + ValuePtr value = GetValueNode(prim_node); + MS_EXCEPTION_IF_NULL(value); + auto prim = value->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto iter = primitive_vars.find(prim); + if (iter == primitive_vars.end()) { + return; + } + (*equiv)[iter->second] = real_node; +} + +EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { + int svar_index = GetSVarStartIndex(values_pattern); + if (svar_index == kInvalidVarIndex) { + return nullptr; + } + + size_t values_pattern_len = values_pattern.size(); + size_t values_expr_len = values_expr.size(); + + if (svar_index == -1) { + if (values_pattern_len != values_expr_len) { + // MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", + // expr len " << values_expr_len; + return nullptr; + } + } + if (values_expr_len < values_pattern_len - 1) { + MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; + return nullptr; + } + size_t diff = values_expr_len - values_pattern_len + 1; + for (size_t i = 0; i < values_pattern_len; i++) { + size_t expr_i = i; + if (svar_index != -1 && i == IntToSize(svar_index)) { + auto seq = + std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); + equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); + } else { + if (svar_index != -1 && i > IntToSize(svar_index)) { + expr_i = i + diff - 1; + } + equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); + } + if (equiv == nullptr) { + return nullptr; + } + } + return equiv; +} + +EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const { + MS_LOG(DEBUG) << "-----[in Match]"; + // MS_LOG(DEBUG) << "GetVar w"; + BaseRef pattern_ref = GetVar(pattern); + // MS_LOG(DEBUG) << "GetVar v"; + BaseRef expr_ref = expr; + + if (equiv == nullptr) { + MS_LOG(EXCEPTION) << "Equiv pointer is null"; + } + + MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); + // 1. if pattern_ref is var and already in equiv, replace it. + if (utils::isa(pattern_ref)) { + VarPtr var = utils::cast(pattern_ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + pattern_ref = iter->second; + } + } + + // 2. check equal + if (eq_(pattern_ref, expr_ref)) { + return equiv; + } + + // 3. match var + EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); + if (ret_equiv) { + return ret_equiv; + } + + // 4. here the type can be std:vector, std:list, + // or cnode. + if (!type_eq_(pattern_ref, expr_ref)) { + MS_LOG(DEBUG) << "Type mismatch"; + return nullptr; + } + + // 5. transfer the Containers by visitor to std::vector + VectorRef values_pattern; + VectorRef values_expr; + if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { + return nullptr; + } + + // 6. if any svar in both side, find the SeqVar index, + // try to pack the Var s in std::vector to a Seq and match elements one by one. + // check svar + equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); + UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); + return equiv; +} + +BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + MS_LOG(DEBUG) << "-----[in Replace]"; + BaseRef ref = GetVar(pattern); + BaseRef out; + bool is_match = false; + + // w is var + if (utils::isa(ref)) { + const VarPtr &var = utils::cast(ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + out = iter->second; + is_match = true; + } + } + if (is_match) { + return out; + } + + // visitor to visit the list + std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; + + visitor_->SetFn(fn); + BaseRef visit_out; + if (!visitor_->Visit(pattern, &visit_out)) { + return pattern; + } + return visit_out; +} +} // namespace mindspore + diff --git a/mindspore/lite/src/gllo/common/pattern_engine.h b/mindspore/lite/src/gllo/common/pattern_engine.h new file mode 100644 index 00000000000..ff1502db5f4 --- /dev/null +++ b/mindspore/lite/src/gllo/common/pattern_engine.h @@ -0,0 +1,203 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/gllo/common/visit.h" +#include "mindspore/core/base/base.h" +#include "utils/log_adapter.h" +#include "base/base_ref.h" + +namespace mindspore { +class CondVar; +class SeqVar; +using CondVarPtr = std::shared_ptr; +using SVarPtr = std::shared_ptr; +const int kInvalidVarIndex = -2; + +using ConditionFunc = std::function; + +// Base wildcard variable which could match any anf node. +class Var : public Base { + friend class VarHasher; + + public: + explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } + explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { + EnsureTag(); + } + Var(const Var &other) : Base(other), tag_(other.tag_) {} + virtual Var &operator=(const Var &other) { + if (&other == this) { + return *this; + } + this->tag_ = other.tag_; + return *this; + } + ~Var() override = default; + MS_DECLARE_PARENT(Var, Base); + + virtual bool matches(const BaseRef &) { return true; } + + virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } + bool operator!=(const Var &other) const { return !(&other == this); } + + std::string tag() const { return tag_; } + PrimitivePtr primitive() const { return primitive_; } + std::string ToString() const override { + std::ostringstream buffer; + buffer << "Var(" << tag_ << ")"; + return buffer.str(); + } + std::size_t hash() const override { return std::hash()(tag_); } + + protected: + void EnsureTag(); + + std::string tag_; + PrimitivePtr primitive_; +}; + +// VarNode means variable node, a subclass of AnfNode +class VarNode : public AnfNode { + public: + VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} + ~VarNode() override = default; + MS_DECLARE_PARENT(VarNode, AnfNode); + + const VarPtr var_; +}; +using VarNodePtr = std::shared_ptr; + +class VarHasher { + public: + std::size_t operator()(const Var &var) const { return var.hash(); } +}; + +// Condition Var, match an anf node when condition function return true. +class CondVar : public Var { + public: + explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} + ~CondVar() override = default; + MS_DECLARE_PARENT(CondVar, Var); + bool matches(const BaseRef &value) override { + // MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); + if (utils::isa(value)) { + return false; + } + return cond_fn_(value); + } + ConditionFunc cond_fn_; +}; + +using Seq = VectorRef; +using SeqPtr = std::shared_ptr; + +// Sequence Var which could match multiple consecutive input nodes of a CNode. +class SeqVar : public Var { + public: + SeqVar() : subvar_(std::make_shared()) {} + ~SeqVar() override = default; + MS_DECLARE_PARENT(SeqVar, Var); + explicit SeqVar(const VarPtr subvar) : subvar_(subvar) {} + bool matches(const BaseRef &value) override { + // match Seq. + if (utils::isa(value)) { + const Seq &seq = utils::cast(value); + return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { + auto eq = subvar_->matches(v); + return eq; + }); + } + return false; + } + bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } + std::string ToString() const override; + + private: + VarPtr subvar_; +}; + +bool operator==(const VarPtr &lhs, const VarPtr &rhs); + +inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } + +std::ostream &operator<<(std::ostream &os, const VarPtr &var); + +using Equiv = std::map; +using EquivPtr = std::shared_ptr; +using PrimitiveVarMap = std::unordered_map; +using PrimitiveVarMapPtr = std::shared_ptr; + +inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } + +class PatternEngine { + public: + PatternEngine(const std::shared_ptr &visitor, + const std::function &eq, + const std::function &type_eq = DefaultTypeEq) + : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} + ~PatternEngine() = default; + + EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const; + // Replace pattern with equivalent + BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; + + private: + EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; + bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, + VectorRef *const values_expr) const; + bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const; + std::shared_ptr visitor_; + std::function eq_; + std::function type_eq_; +}; +} // namespace mindspore +namespace std { +using mindspore::ERROR; +using mindspore::LogStream; +using mindspore::NoExceptionType; +template <> +struct hash { + std::size_t operator()(const mindspore::VarPtr var) const { + if (var == nullptr) { + MS_LOG(ERROR) << "Invalid var ptr"; + return 0; + } + return std::hash{}(var->tag()); + } +}; +} // namespace std +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_ + diff --git a/mindspore/lite/src/gllo/common/utils.cc b/mindspore/lite/src/gllo/common/utils.cc new file mode 100644 index 00000000000..69d3e58a25e --- /dev/null +++ b/mindspore/lite/src/gllo/common/utils.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "src/gllo/common/utils.h" +#include "mindspore/lite/src/ir/primitive_t_value.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { + +bool AnfEqual(const BaseRef &a, const BaseRef &b) { + if (utils::isa(a) && utils::isa(b)) { + return true; + } else if (utils::isa(a) && utils::isa(b)) { + auto a_node = utils::cast(a); + auto b_node = utils::cast(b); + MS_EXCEPTION_IF_NULL(a_node); + MS_EXCEPTION_IF_NULL(b_node); + if (IsValueNode(a_node) && IsValueNode(b_node)) { + auto a_value_node = a_node->cast(); + MS_EXCEPTION_IF_NULL(a_value_node); + auto a_value = a_value_node->value(); + MS_EXCEPTION_IF_NULL(a_value); + auto a_prim = a_value->cast(); + MS_EXCEPTION_IF_NULL(a_prim); + + auto b_value_node = b_node->cast(); + MS_EXCEPTION_IF_NULL(b_value_node); + auto b_value = b_value_node->value(); + MS_EXCEPTION_IF_NULL(b_value); + auto b_prim = b_value->cast(); + MS_EXCEPTION_IF_NULL(b_prim); + + return a_prim->name() == b_prim->name(); + } else if (a_node->isa() && b_node->isa()) { + auto a_value_node_ptr = a_node->cast(); + if (a_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto a_value_ptr = a_value_node_ptr->value(); + if (a_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + auto b_value_node_ptr = b_node->cast(); + if (b_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto b_value_ptr = b_value_node_ptr->value(); + if (b_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + if (utils::isa(a_value_ptr) && utils::isa(b_value_ptr)) { + auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get()); + auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get()); + return (*a_obj) == (*b_obj); + } else { + return (*a_value_ptr) == (*b_value_ptr); + } + } + } + + return a == b; +} + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { + // To matchCNode and Kernel's type + if (utils::isa(a) && utils::isa(b)) { + return true; + } + return a.type() == b.type(); +} + +namespace { +ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + return nullptr; +} + +CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + return nullptr; +} + +VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { + if (utils::isa(graph)) { + // MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); + return std::make_shared(utils::cast(sexp), nullptr); + } + if (utils::isa(graph)) { + // MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); + return std::make_shared(utils::cast(sexp), utils::cast(graph)); + } + MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); + return nullptr; +} + +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { + // MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + std::vector input_nodes; + const auto &tuple = utils::cast(sexp); + if (multigraph && utils::isa(graph)) { + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); + input_nodes.push_back(node); + } + VarPtr var_ptr = utils::cast(graph); + return std::make_shared(input_nodes, var_ptr); + } + + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); + input_nodes.push_back(node); + } + return CreateCNodeWithGraph(input_nodes, graph); +} +} // namespace + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { + // MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); + if (utils::isa(sexp)) { + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); + } + if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } + return CreateVarNodeWithSexp(sexp, graph); + } + if (utils::isa(sexp)) { + return utils::cast(sexp); + } + auto value_node = CreateValueNodeWithSexp(sexp); + if (value_node == nullptr) { + MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); + } + return value_node; +} + +void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) { + if (graph == nullptr) { + MS_LOG(EXCEPTION) << "The graph is null."; + } +} + +void CheckIfAnfNodeIsNull(const AnfNodePtr &node) { + if (node == nullptr) { + MS_LOG(EXCEPTION) << "The AnfNode is null."; + } +} + +void CheckIfCNodeIsNull(const CNodePtr &node) { + if (node == nullptr) { + MS_LOG(EXCEPTION) << "The CNode is null."; + } +} + +void CheckIfVarIsNull(const VarPtr &var) { + if (var == nullptr) { + MS_LOG(EXCEPTION) << "The Var is null."; + } +} + +void CheckInputSize(const CNodePtr &node, const int size) { + if (node->inputs().size() != size) { + MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size(); + } +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/lite/src/gllo/common/utils.h b/mindspore/lite/src/gllo/common/utils.h new file mode 100644 index 00000000000..ffd57de6188 --- /dev/null +++ b/mindspore/lite/src/gllo/common/utils.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ +#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "src/common/utils.h" +#include "src/gllo/common/pattern_engine.h" + +namespace mindspore { +namespace opt { + +bool AnfEqual(const BaseRef &a, const BaseRef &b); + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph = false); + +void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph); + +void CheckIfAnfNodeIsNull(const AnfNodePtr &node); + +void CheckIfCNodeIsNull(const CNodePtr &node); + +void CheckIfVarIsNull(const VarPtr &var); + +void CheckInputSize(const CNodePtr &node, const int size); + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_ + diff --git a/mindspore/lite/src/gllo/common/visit.cc b/mindspore/lite/src/gllo/common/visit.cc new file mode 100644 index 00000000000..d00744e6563 --- /dev/null +++ b/mindspore/lite/src/gllo/common/visit.cc @@ -0,0 +1,165 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include +#include +#include + +#include "src/gllo/common/visit.h" +#include "src/gllo/common/pattern_engine.h" +#include "utils/any.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "utils/log_adapter.h" + + +namespace mindspore { +bool CheckIfNeedExpand(const std::vector &list) { + return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); +} + +std::shared_ptr ExpandList(const std::vector &list) { + std::shared_ptr new_list = std::make_shared(); + for (auto &item : list) { + if (utils::isa(item)) { + const Seq &seq = utils::cast(item); + new_list->insert(new_list->end(), seq.begin(), seq.end()); + } else { + new_list->push_back(item); + } + } + return new_list; +} + +bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { + std::vector out; + (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), + [this](const BaseRef &item) { return fn_(item); }); + if (visit_out != nullptr) { + *visit_out = ExpandList(out); + } + return true; +} + +bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { + if (utils::isa(any)) { + return Visit(utils::cast(any), visit_out); + } else if (utils::isa(any)) { + auto nodeptr = utils::cast(any); + AnfNodePtr output; + AnfNodePtr *p_output = &output; + if (visit_out == nullptr) { + p_output = nullptr; + } + Visit(nodeptr, fn_, p_output); + if (visit_out != nullptr) { + *visit_out = output; + } + return true; + } + MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); + return false; +} + +void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (output != nullptr) { + *output = node; + } +} + +void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { + // if output is nullptr, it's not required to make the new CNode node. + if (output == nullptr) { + for (auto &inp : cnode->inputs()) { + (void)fn(inp); + } + + if (cnode->func_graph() != nullptr) { + (void)fn(cnode->func_graph()); + } else { + (void)fn(cnode->func_graph_as_var()); + } + return; + } + + std::vector new_inputs; + std::vector after_cnode_fn; + std::shared_ptr out; + (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); + if (CheckIfNeedExpand(after_cnode_fn)) { + out = ExpandList(after_cnode_fn); + } + + std::vector &outs = after_cnode_fn; + if (out != nullptr) { + outs = out->elements(); + } + + for (auto &any_item : outs) { + if (!utils::isa(any_item)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; + } + new_inputs.push_back(utils::cast(any_item)); + } + + BaseRef any_fg; + AnfNodePtr new_cnode = nullptr; + if (cnode->func_graph() != nullptr) { + any_fg = fn(cnode->func_graph()); + if (!utils::isa(any_fg)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; + } + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + any_fg = fn(cnode->func_graph_as_var()); + if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; + } + } + new_cnode->set_abstract(cnode->abstract()); + *output = new_cnode; +} + +void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { + const BaseRef &value = utils::cast(fn(vnode->value())); + if (utils::isa(value)) { + if (output != nullptr) { + auto ct = NewValueNode(utils::cast(value)); + ct->set_abstract(vnode->abstract()); + *output = ct; + } + return; + } + MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; +} +} // namespace mindspore + diff --git a/mindspore/lite/src/gllo/common/visit.h b/mindspore/lite/src/gllo/common/visit.h new file mode 100644 index 00000000000..548e5e033d8 --- /dev/null +++ b/mindspore/lite/src/gllo/common/visit.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ +#define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ + +#include +#include +#include +#include +#include +#include + +#include "mindspore/core/base/base.h" +#include "base/base_ref.h" + +namespace mindspore { +using VisitFn = std::function; + +class Visitor { + public: + virtual void SetFn(VisitFn fn) = 0; + virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0; + virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0; + virtual ~Visitor() = default; +}; + +class DefaultVisitor : public Visitor { + public: + DefaultVisitor() : fn_(nullptr) {} + ~DefaultVisitor() override = default; + void SetFn(VisitFn fn) override { fn_ = fn; }; + bool Visit(const VectorRef &e, BaseRef *out) const override; + bool Visit(const BaseRef &e, BaseRef *out) const override; + void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const; + void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const; + void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const; + + VisitFn fn_; +}; + +std::shared_ptr ExpandList(const std::vector &list); +bool CheckIfNeedExpand(const std::vector &list); +} // namespace mindspore +#endif // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_ + diff --git a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc new file mode 100644 index 00000000000..0b20a2baf8c --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/gllo/fusion/conv_biasadd_fusion.h" +#include +#include "mindspore/lite/schema/inner/model_generated.h" +#include "mindspore/lite/src/ir/primitive_t_value.h" +#include "mindspore/ccsrc/utils/utils.h" +#include "src/gllo/common/utils.h" + +namespace mindspore { +namespace opt { + +const BaseRef ConvBiasaddFusion::DefinePattern() const { + MS_LOG(DEBUG) << "Enter pattern"; + + VarPtr X = std::make_shared(); + VarPtr W = std::make_shared(); + VarPtr B = std::make_shared(); + CheckIfVarIsNull(X); + CheckIfVarIsNull(W); + CheckIfVarIsNull(B); + + auto prim1 = new schema::PrimitiveT(); + prim1->value.type = schema::PrimitiveType_BiasAdd; + auto prim11 = std::make_shared(prim1); + + auto prim2 = new schema::PrimitiveT(); + prim2->value.type = schema::PrimitiveType_Conv2D; + auto prim22 = std::make_shared(prim2); + + return VectorRef({prim11, VectorRef({prim22, X, W}), B}); +} + +const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "Enter pass process"; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto cnode = node->cast(); + CheckIfCNodeIsNull(cnode); + CheckInputSize(cnode, 3); // [op, conv_node, bias_node] + + AnfNodePtr conv_node_anf = cnode->input(1); + CheckIfAnfNodeIsNull(conv_node_anf); + auto conv_node = conv_node_anf->cast(); + CheckIfCNodeIsNull(conv_node); + CheckInputSize(conv_node, 3); // [op, X, W] + + conv_node->add_input(cnode->input(2)); + + auto primitive = (lite::PrimitiveTValue *)(conv_node->input(0)->cast()->value().get()); + primitive->GetPrimitiveT()->value.AsConv2D()->hasBias = true; + + return conv_node_anf; +} + +} // namespace opt +} // namespace mindspore + diff --git a/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h new file mode 100644 index 00000000000..df0f393ad51 --- /dev/null +++ b/mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ + +#include "src/gllo/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvBiasaddFusion : public PatternProcessPass { + public: + explicit ConvBiasaddFusion(bool multigraph = true) : PatternProcessPass("conv_biasadd_fusion", multigraph) {} + ~ConvBiasaddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ + diff --git a/mindspore/lite/src/ir/meta_tensor_extends.cc b/mindspore/lite/src/ir/meta_tensor_extends.cc new file mode 100644 index 00000000000..3e5851ba334 --- /dev/null +++ b/mindspore/lite/src/ir/meta_tensor_extends.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/meta_tensor.h" + +namespace mindspore { +namespace tensor { +abstract::AbstractBasePtr MetaTensor::ToAbstract() { + MS_LOG(ERROR) << "MetaTensor ToAbstract is not implemented"; + return nullptr; +} +TypePtr MetaTensor::Dtype() const { return nullptr; } +} // namespace tensor +} // namespace mindspore + diff --git a/mindspore/lite/src/ir/primitive_t_value.cc b/mindspore/lite/src/ir/primitive_t_value.cc new file mode 100644 index 00000000000..9c27cc66fd3 --- /dev/null +++ b/mindspore/lite/src/ir/primitive_t_value.cc @@ -0,0 +1,17 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ir/primitive_t_value.h" diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h new file mode 100644 index 00000000000..56667890f3e --- /dev/null +++ b/mindspore/lite/src/ir/primitive_t_value.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ + +#include +#include "ir/value.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore::lite { + +class PrimitiveTValue : public Value { + public: + explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {} + + ~PrimitiveTValue() override { delete this->primitive; } + + MS_DECLARE_PARENT(PrimitiveTValue, Value) + + schema::PrimitiveT *GetPrimitiveT() const { return this->primitive; } + + void SetPrimitiveT(schema::PrimitiveT *primIn) { this->primitive = primIn; } + + bool operator==(const Value &rhs) const override { + if (rhs.isa()) { + auto other_prim = static_cast(rhs); + auto a = this->primitive->value.type; + auto b = other_prim.primitive->value.type; + return a == b; + } else { + return false; + } + } + + void AddInputQuantParam(schema::QuantParamT quant_param) { + this->input_quant_param_.emplace_back(quant_param); + } + std::vector GetInputQuantParams() const { + return input_quant_param_; + } + + void AddOutputQuantParam(schema::QuantParamT quant_param) { + this->output_quant_param_.emplace_back(quant_param); + } + std::vector GetOutputQuantParams() const { + return output_quant_param_; + } + + void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } + + schema::QuantType GetQuantType() const { return quant_type_; } + + protected: + schema::PrimitiveT *primitive = nullptr; + std::vector input_quant_param_; + std::vector output_quant_param_; + schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ + diff --git a/mindspore/lite/src/ir/primitive_value.cc b/mindspore/lite/src/ir/primitive_value.cc new file mode 100644 index 00000000000..ebd5d4d6152 --- /dev/null +++ b/mindspore/lite/src/ir/primitive_value.cc @@ -0,0 +1,19 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ir/primitive_value.h" + + diff --git a/mindspore/lite/src/ir/primitive_value.h b/mindspore/lite/src/ir/primitive_value.h new file mode 100644 index 00000000000..66202d15e6c --- /dev/null +++ b/mindspore/lite/src/ir/primitive_value.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ +#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ + +#include "ir/value.h" +#include "src/ops/ops.h" + +namespace mindspore::lite { +class PrimitiveValue : public Value { + public: + explicit PrimitiveValue(const lite::Primitive *prim) : primitive(prim) {} + + const lite::Primitive *GetPrimitive() const { + return this->primitive; + } + MS_DECLARE_PARENT(PrimitiveValue, Value) + bool operator==(const Value &rhs) const override { + if (rhs.isa()) { + auto other_prim = static_cast(rhs); + return *this == other_prim; + } else { + return false; + } + } + + protected: + const lite::Primitive *primitive = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVE_H_ + diff --git a/mindspore/lite/src/ir/tensor.cc b/mindspore/lite/src/ir/tensor.cc new file mode 100644 index 00000000000..24c433c7a53 --- /dev/null +++ b/mindspore/lite/src/ir/tensor.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "src/ir/tensor.h" +#include "securec/include/securec.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +namespace tensor { +#define kMaxMallocSize 1024 * 1024 * 100 +Tensor::Tensor(const TypeId data_type, const std::vector &shape, const schema::Format &format, + schema::NodeType tensorType) + : MetaTensor(data_type, shape), format_(format), tensorType(tensorType) {} + +Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) { + auto ret = CopyTensor(tensor, true); + if (0 != ret) { + MS_LOG(EXCEPTION) << "CopyTensorData error"; + } +} + +int Tensor::CopyTensorData(const Tensor &srcTensor) { + if (srcTensor.data_ == nullptr) { + MS_LOG(ERROR) << "data of srcTensor is nullptr"; + return mindspore::lite::RET_PARAM_INVALID; + } + size_t data_size = this->Size(); + MS_ASSERT(data_size == srcTensor.Size()); + if (this->data_ == nullptr) { + if (data_size > kMaxMallocSize) { + MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes"; + return mindspore::lite::RET_ERROR; + } + this->data_ = malloc(data_size); + } + auto ret = memcpy_s(this->data_, data_size, srcTensor.data_, srcTensor.Size()); + if (EOK != ret) { + MS_LOG(ERROR) << "memcpy_s failed : " << ret; + return mindspore::lite::RET_ERROR; + } + return 0; +} + +int Tensor::CopyTensor(const Tensor &srcTensor, bool copyData) { + this->data_type_ = srcTensor.data_type_; + this->shape_ = srcTensor.shape_; + this->tensorType = srcTensor.tensorType; + if (copyData) { + auto ret = CopyTensorData(srcTensor); + if (0 != ret) { + MS_LOG(ERROR) << "CopyTensorData error"; + return mindspore::lite::RET_ERROR; + } + } + return 0; +} + +Tensor::~Tensor() { + if (nullptr != this->data_) { + free(this->data_); + } +} + +Tensor &Tensor::operator=(const Tensor &tensor) { + if (&tensor == this) { + return *this; + } + auto ret = CopyTensor(tensor, true); + if (0 != ret) { + MS_LOG(ERROR) << "CopyTensorData error"; + MS_ASSERT(false); + } + return *this; +} + +bool Tensor::operator==(const Tensor &tensor) { + return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_; +} + +bool Tensor::operator==(const Value &other) const { + if (other.isa()) { + auto other_ = static_cast(other); + return *this == other_; + } else { + return false; + } +} + +int32_t Tensor::Batch() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NCHW: + case schema::Format_NC4HW4: + case schema::Format_KCHW: + case schema::Format_KHWC: + return this->shape_[0]; + case schema::Format_HWCK: + case schema::Format_CHWK: + return this->shape_[3]; + case schema::Format_HWKC: + return this->shape_[2]; + case schema::Format_CKHW: + return this->shape_[1]; + default: + MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + return -1; + } +} + +int32_t Tensor::Channel() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + return this->shape_[1]; + case schema::Format_HWCK: + return this->shape_[2]; + case schema::Format_HWKC: + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_KHWC: + return this->shape_[3]; + case schema::Format_CKHW: + case schema::Format_CHWK: + return this->shape_[0]; + default: + return -1; + } +} + +int32_t Tensor::Height() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + case schema::Format_CKHW: + return this->shape_[2]; + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_KHWC: + case schema::Format_CHWK: + return this->shape_[1]; + case schema::Format_HWCK: + case schema::Format_HWKC: + return this->shape_[0]; + default: + MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + return -1; + } +} + +int32_t Tensor::Width() const { + if (this->shape_.size() != 4) { + MS_LOG(ERROR) << "tensor should have 4 dim"; + return -1; + } + switch (this->format_) { + case schema::Format_NCHW: + case schema::Format_KCHW: + case schema::Format_CKHW: + return this->shape_[3]; + case schema::Format_KHWC: + case schema::Format_NHWC: + case schema::Format_NHWC4: + case schema::Format_NC4HW4: + case schema::Format_CHWK: + return this->shape_[2]; + case schema::Format_HWCK: + case schema::Format_HWKC: + return this->shape_[1]; + default: + return -1; + } +} + +std::string Tensor::ToString() const { + std::ostringstream oss; + oss << "Format: " << schema::EnumNameFormat(this->format_); + oss << " DataType: " << this->data_type_; + oss << " NodeType: " << schema::EnumNameNodeType(this->tensorType); + oss << " Shape:"; + for (auto &dim : this->shape()) { + oss << " " << dim; + } + oss << std::endl << "Data:"; + switch (this->data_type_) { + case kNumberTypeFloat32: { + auto data = static_cast(this->data_); + if (data == nullptr) { + return "Data of tensor is nullptr"; + } else { + for (size_t i = 0; i < 40 && i < this->ElementsNum(); i++) { + oss << " " << data[i]; + } + } + } break; + case kNumberTypeInt32: { + auto data = static_cast(this->data_); + if (data == nullptr) { + return "Data of tensor is nullptr"; + } else { + for (size_t i = 0; i < 40 && i < this->ElementsNum(); i++) { + oss << " " << data[i]; + } + } + } break; + default: + oss << "Unsupport data type to print"; + break; + } + return oss.str(); +} + +void Tensor::AddQuantParam(const tensor::QuantArg &quant_arg) { this->quant_params_.push_back(quant_arg); } + +std::vector Tensor::GetQuantParams() const { return this->quant_params_; } + +LiteTensor::LiteTensor() { this->tensor_impl_ = new tensor::Tensor(); } + +LiteTensor::LiteTensor(TypeId data_type, const std::vector &shape) { + this->tensor_impl_ = new tensor::Tensor(data_type, shape); +} + +LiteTensor::LiteTensor(tensor::Tensor *tensor_ptr) { this->tensor_impl_ = tensor_ptr; } + +TypeId LiteTensor::data_type() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->data_type(); +} + +TypeId LiteTensor::set_data_type(TypeId data_type) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_data_type(data_type); +} + +std::vector LiteTensor::shape() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->shape(); +} + +size_t LiteTensor::set_shape(const std::vector &shape) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_shape(shape); +} + +int LiteTensor::DimensionSize(size_t index) const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->DimensionSize(index); +} + +int LiteTensor::ElementsNum() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->ElementsNum(); +} + +std::size_t LiteTensor::hash() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->hash(); +} + +tensor::Tensor *LiteTensor::tensor() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_; +} + +size_t LiteTensor::Size() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->Size(); +} + +void *LiteTensor::MutableData() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + auto data = this->tensor_impl_->Data(); + if (nullptr == data) { + auto ret = tensor_impl_->MallocData(); + if (0 != ret) { + return nullptr; + } + } + return this->tensor_impl_->Data(); +} +LiteTensor::~LiteTensor() { delete this->tensor_impl_; } + +void LiteTensor::SetTensorImpl(tensor::Tensor *tensor) { this->tensor_impl_ = tensor; } +} // namespace tensor +} // namespace lite +namespace tensor { +MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { + return new mindspore::lite::tensor::LiteTensor(data_type, shape); +} +} // namespace tensor +} // namespace mindspore + diff --git a/mindspore/lite/src/ir/tensor.h b/mindspore/lite/src/ir/tensor.h new file mode 100644 index 00000000000..8a84da3d4fd --- /dev/null +++ b/mindspore/lite/src/ir/tensor.h @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_IR_TENSOR_H_ +#define MINDSPORE_LITE_SRC_IR_TENSOR_H_ + +#include +#include +#include +#include "ir/meta_tensor.h" +#include "include/ms_tensor.h" +#include "ir/dtype/type_id.h" +#include "src/runtime/allocator.h" +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +namespace tensor { + +struct QuantArg { + double scale; + int32_t zeroPoint; +}; + +class Tensor : public mindspore::tensor::MetaTensor { + public: + Tensor() : MetaTensor() {} + + Tensor(const TypeId data_type, const std::vector &shape, const schema::Format &format = schema::Format_NHWC, + schema::NodeType tensorType = schema::NodeType_Parameter); + + Tensor(const Tensor &tensor); + + ~Tensor() override; + + int CopyTensorData(const Tensor &srcTensor); + + int CopyTensor(const Tensor &srcTensor, bool copyData = false); + + MS_DECLARE_PARENT(Tensor, MetaTensor) + + virtual Tensor &operator=(const Tensor &tensor); + + virtual bool operator==(const Tensor &tensor); + + bool operator==(const Value &other) const override; + + int32_t Batch() const; + + int32_t Channel() const; + + int32_t Height() const; + + int32_t Width() const; + + int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); } + + int DataSize() const { return this->ElementsNum(); } + + size_t Size() const { + size_t size = 0; + switch (this->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + size = sizeof(float); + break; + case kNumberTypeInt8: + size = sizeof(int8_t); + break; + case kNumberTypeUInt8: + size = sizeof(uint8_t); + break; + case kNumberTypeFloat16: + size = sizeof(int16_t); + break; + case kNumberTypeInt16: + size = sizeof(int16_t); + break; + case kNumberTypeInt32: + size = sizeof(int32_t); + break; + case kNumberTypeInt64: + size = sizeof(int64_t); + break; + case kNumberTypeUInt16: + size = sizeof(uint16_t); + break; + case kNumberTypeUInt32: + size = sizeof(uint32_t); + break; + case kNumberTypeUInt64: + size = sizeof(uint64_t); + break; + case kNumberTypeBool: + size = sizeof(bool); + break; + default: + MS_LOG(ERROR) << "Not support the type: " << this->data_type_; + return 0; + } + size *= (format_ == schema::Format_NC4HW4 || format_ == schema::Format_NHWC4) ? ElementsC4Num() + : MetaTensor::ElementsNum(); + + return size; + } + + int MallocData(mindspore::lite::Allocator *allocator = nullptr) { + if (nullptr != this->data_) { + return 0; + } + if (nullptr == allocator) { + this->data_ = malloc(this->Size()); + } else { + this->data_ = allocator->Malloc(this->Size()); + } + if (nullptr == this->data_) { + MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size(); + return -1; + } + + return 0; + } + + int FreeData(mindspore::lite::Allocator *allocator = nullptr) { + if (nullptr == this->data_) { + return 0; + } + if (nullptr == allocator) { + free(this->data_); + } else { + allocator->Free(this->data_); + this->data_ = nullptr; + } + + return 0; + } + + void *Data() { return data_; } + + void SetData(void *data) { this->data_ = data; } + + schema::NodeType TensorType() { return this->tensorType; } + + void SetFormat(schema::Format format) { this->format_ = format; } + + schema::Format GetFormat() { return this->format_; } + + size_t RefCount() { return this->refCount; } + + void SetRefCount(size_t refCount) { this->refCount = refCount; } + + void decRefCount() { this->refCount--; } + + std::string ToString() const override; + + void AddQuantParam(const tensor::QuantArg &quant_arg); + + std::vector GetQuantParams() const; + + protected: + void *data_ = nullptr; + void *device_data_ = nullptr; + schema::NodeType tensorType; + schema::Format format_; + size_t refCount = 0; + std::vector quant_params_; +}; + +class LiteTensor : public mindspore::tensor::MSTensor { + public: + LiteTensor(); + + LiteTensor(TypeId data_type, const std::vector &shape); + + explicit LiteTensor(tensor::Tensor *tensor_ptr); + + ~LiteTensor() override; + + TypeId data_type() const override; + + TypeId set_data_type(const TypeId data_type) override; + + std::vector shape() const override; + + size_t set_shape(const std::vector &shape) override; + + int DimensionSize(size_t index) const override; + + int ElementsNum() const override; + + std::size_t hash() const override; + + tensor::Tensor *tensor() const; + + size_t Size() const override; + + void *MutableData() const override; + + void SetTensorImpl(tensor::Tensor *tensor); + + protected: + tensor::Tensor *tensor_impl_; +}; + +using TensorPtr = std::shared_ptr; +} // namespace tensor +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_IR_TENSOR_H_ + diff --git a/mindspore/lite/src/kernel_factory.cc b/mindspore/lite/src/kernel_factory.cc new file mode 100644 index 00000000000..aa81d23c905 --- /dev/null +++ b/mindspore/lite/src/kernel_factory.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/kernel_factory.h" +#include "utils/log_adapter.h" +#include "src/populate_parameter.h" +#include "schema/model_generated.h" + +using mindspore::kernel::KERNEL_ARCH; +using mindspore::kernel::KernelKey; +using mindspore::kernel::LiteKernel; + +namespace mindspore::lite { +KernelFactory::KernelFactory() = default; + +KernelFactory::~KernelFactory() = default; + +KernelFactory *KernelFactory::GetInstance() { + static KernelFactory instance; + return &instance; +} + +LiteKernel *KernelFactory::GetKernel(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive, + const Context *ctx, const kernel::KernelKey &key) { + MS_EXCEPTION_IF_NULL(primitive); + MS_EXCEPTION_IF_NULL(ctx); + auto parameter = kernel::PopulateParameter(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); + return nullptr; + } + auto creator = KernelRegistry::GetInstance()->GetKernelCreator(key); + if (creator != nullptr) { + auto *kernel = creator(inputs, outputs, parameter, ctx, key); + if (kernel != nullptr) { + return kernel; + } else { + MS_LOG(ERROR) << "Creator kernel failed for " << schema::EnumNamePrimitiveType(key.type); + return nullptr; + } + } else { + MS_LOG(ERROR) << "Can not find OpCreator for " << schema::EnumNamePrimitiveType(key.type); + return nullptr; + } +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/kernel_factory.h b/mindspore/lite/src/kernel_factory.h new file mode 100644 index 00000000000..136008959bf --- /dev/null +++ b/mindspore/lite/src/kernel_factory.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ +#define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ + +#include +#include "mindspore/lite/src/lite_kernel.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/include/context.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "schema/model_generated.h" + +namespace mindspore::lite { +class KernelFactory { + public: + KernelFactory(); + virtual ~KernelFactory(); + + static KernelFactory *GetInstance(); + kernel::LiteKernel *GetKernel(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive, + const Context *ctx, const kernel::KernelKey &key); +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ + diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc new file mode 100644 index 00000000000..9f48c6d5952 --- /dev/null +++ b/mindspore/lite/src/kernel_registry.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/kernel_registry.h" + +using mindspore::kernel::KernelCreator; +using mindspore::kernel::KernelKey; +using mindspore::kernel::KERNEL_ARCH; + +namespace mindspore::lite { +KernelRegistry::KernelRegistry() {} + +KernelRegistry::~KernelRegistry() {} + +KernelRegistry *KernelRegistry::GetInstance() { + static KernelRegistry instance; + return &instance; +} + +KernelCreator KernelRegistry::GetKernelCreator(const KernelKey &desc) { + auto it = creators.find(desc); + if (it != creators.end()) { + return it->second; + } + + // if not find, use cpu kernel + KernelKey cpuDesc {kernel::KERNEL_ARCH::kCPU, desc.type}; + it = creators.find(cpuDesc); + if (it != creators.end()) { + return it->second; + } + return nullptr; +} + +void KernelRegistry::RegKernel(const KernelKey desc, KernelCreator creator) { creators[desc] = creator; } + +void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const schema::PrimitiveType type, KernelCreator creator) { + KernelKey desc = {arch, type}; + creators[desc] = creator; +} + +bool KernelRegistry::Merge(const std::unordered_map &newCreators) { return false; } + +const std::map &KernelRegistry::GetKernelCreators() { return creators; } +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h new file mode 100644 index 00000000000..7873ac2a751 --- /dev/null +++ b/mindspore/lite/src/kernel_registry.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ +#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ + +#include +#include +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + +namespace mindspore::lite { +class KernelRegistry { + public: + KernelRegistry(); + virtual ~KernelRegistry(); + + static KernelRegistry *GetInstance(); + virtual kernel::KernelCreator GetKernelCreator(const kernel::KernelKey &desc); + + const std::map &GetKernelCreators(); + + void RegKernel(const kernel::KernelKey desc, kernel::KernelCreator creator); + void RegKernel(const kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator); + bool Merge(const std::unordered_map &newCreators); + + protected: + std::map creators; +}; + +class KernelRegistrar { + public: + KernelRegistrar(const kernel::KernelKey &desc, kernel::KernelCreator creator) { + KernelRegistry::GetInstance()->RegKernel(desc, creator); + } + + KernelRegistrar(const kernel::KERNEL_ARCH arch, const schema::PrimitiveType type, kernel::KernelCreator creator) { + KernelRegistry::GetInstance()->RegKernel(arch, type, creator); + } +}; + +#define REG_KERNEL(arch, type, kernelCreater) \ + static KernelRegistrar g_##arch##type##kernelReg(arch, type, kernelCreater); +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ + diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc new file mode 100644 index 00000000000..43d2b9d3883 --- /dev/null +++ b/mindspore/lite/src/lite_kernel.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/lite_kernel.h" +#include +#include "src/common/utils.h" + +namespace mindspore::kernel { +void LiteKernel::InitOutTensorRefCount() { + for (auto *tensor : this->outputs_) { + tensor->SetRefCount(this->out_kernel_.size()); + } +} + +int LiteKernel::DecOutTensorRefCount(lite::Allocator *allocator) { + for (auto *tensor : this->outputs_) { + tensor->decRefCount(); + if (0 >= tensor->RefCount()) { + auto ret = tensor->FreeData(allocator); + if (0 != ret) { + MS_LOG(ERROR) << "Free tensor data failed"; + return ret; + } + } + } + return 0; +} + +std::vector LiteKernelUtil::SubgraphInputKernels( + const std::vector &kernels) { + std::vector input_kernels; + for (const auto kernel : kernels) { + for (auto input : kernel->GetInKernels()) { + auto iter = std::find(kernels.begin(), kernels.end(), input); + if (iter == kernels.end()) { + input_kernels.emplace_back(input); + } + } + } + return input_kernels; +} + +std::vector LiteKernelUtil::SubgraphOutputKernels( + const std::vector &kernels) { + std::vector output_kernels; + for (const auto kernel : kernels) { + for (const auto output : kernel->GetOutKernels()) { + auto iter = std::find(kernels.begin(), kernels.end(), output); + if (iter == kernels.end()) { + output_kernels.emplace_back(output); + } + } + } + return output_kernels; +} + +std::vector LiteKernelUtil::SubgraphInputTensors( + const std::vector &kernels) { + std::vector input_tensors; + std::vector all_output_tensors; + for (const auto &kernel : kernels) { + all_output_tensors.insert(all_output_tensors.end(), kernel->GetOutputs().begin(), kernel->GetOutputs().end()); + } + std::sort(all_output_tensors.begin(), all_output_tensors.end()); + auto end_iter = std::unique(all_output_tensors.begin(), all_output_tensors.end()); + all_output_tensors.erase(end_iter, all_output_tensors.end()); + + std::vector input_kernels = SubgraphInputKernels(kernels); + for (const auto &kernel : input_kernels) { + for (const auto &tensor : kernel->GetInputs()) { + auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); + if (iter == all_output_tensors.end() && tensor->Data() == nullptr) { + input_tensors.emplace_back(tensor); + } + } + } + return input_tensors; +} + +std::vector LiteKernelUtil::SubgraphOutputTensors( + const std::vector &kernels) { + std::vector output_tensors; + std::vector all_input_tensors; + for (const auto &kernel : kernels) { + all_input_tensors.insert(all_input_tensors.end(), kernel->GetInputs().begin(), kernel->GetInputs().end()); + } + std::sort(all_input_tensors.begin(), all_input_tensors.end()); + auto end_iter = std::unique(all_input_tensors.begin(), all_input_tensors.end()); + all_input_tensors.erase(end_iter, all_input_tensors.end()); + + std::vector output_kernels = SubgraphOutputKernels(kernels); + for (const auto &kernel : output_kernels) { + for (const auto &tensor : kernel->GetOutputs()) { + auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); + if (iter == all_input_tensors.end()) { + output_tensors.emplace_back(tensor); + } + } + } + return output_tensors; +} + +void LiteKernelUtil::TopologicalSortKernels(std::vector &kernels) { + for (auto *kernel : kernels) { + for (auto *search_kernel : kernels) { + if (search_kernel == kernel) { + continue; + } + for (auto *tensor : kernel->GetInputs()) { + if (lite::IsContain(search_kernel->GetOutputs(), tensor)) { + kernel->AddInKernel(search_kernel); + } + } + for (auto *tensor : kernel->GetOutputs()) { + if (lite::IsContain(search_kernel->GetInputs(), tensor)) { + kernel->AddOutKernel(search_kernel); + } + } + } + } +} + +void LiteKernelUtil::InitTensorRefCount(std::vector &kernels) { + for (auto *kernel : kernels) { + kernel->InitOutTensorRefCount(); + } +} + +int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector inputs) { return -1; } +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h new file mode 100644 index 00000000000..ea6ba756483 --- /dev/null +++ b/mindspore/lite/src/lite_kernel.h @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_LITE_KERNEL_H_ +#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_ +#include +#include +#ifdef ENABLE_FP16 +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +// #include "backend/kernel_compiler/kernel.h" +#include "include/context.h" +#include "src/ir/tensor.h" +#include "src/ops/ops.h" + +#ifdef ENABLE_FP16 +using FLOAT_t = float16_t; +#else +using FLOAT_t = float; +#endif + +// using mindspore::kernel::AddressPtr; +namespace mindspore::kernel { +enum KERNEL_ARCH { kCPU, kGPU, kNPU, kInferShape }; +struct KernelKey { + KERNEL_ARCH arch; + schema::PrimitiveType type; + + bool operator<(const KernelKey &dst) const { + if (arch != dst.arch) { + return arch < dst.arch; + } else { + return type < dst.type; + } + } +}; + +class LiteKernel; +struct CallBackParam { + std::string name_callback_aram; +}; + +using KernelCallBack = std::function inputs, + std::vector outputs, const CallBackParam &opInfo)>; + +// class LiteKernel : public KernelMod { +class LiteKernel { + public: + LiteKernel() = default; + explicit LiteKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : opParameter(parameter), inputs_(inputs), outputs_(outputs) { + this->in_kernel_.clear(); + this->out_kernel_.clear(); + } + + virtual ~LiteKernel() { delete opParameter; } + + // bool Launch(const std::vector &inputs, const std::vector &workspace, + // const std::vector &outputs, void *stream_ptr) override { + // return false; + // }; + // + // const std::vector &GetInputSizeList() const override { return {}; } + // + // const std::vector &GetOutputSizeList() const override { return {}; } + // + // const std::vector &GetWorkspaceSizeList() const override { return {}; } + + virtual int Prepare() { return -1; } + virtual int Init() { return -1; } + virtual int ReSize() { return -1; } + virtual int Run() { return -1; } + + std::string Name() { return this->name; } + + void set_name(const std::string &name) { this->name = name; } + + schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } + + std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); } + + void SetInputs(const std::vector &inputs) { this->inputs_ = inputs; } + + void SetOutputs(const std::vector &outputs) { this->outputs_ = outputs; } + + std::vector &GetInputs() { return this->inputs_; } + + std::vector &GetOutputs() { return this->outputs_; } + + void AddInKernel(LiteKernel *kernel) { this->in_kernel_.emplace_back(kernel); } + + void AddOutKernel(LiteKernel *kernel) { this->out_kernel_.emplace_back(kernel); } + + std::vector &GetInKernels() { return this->in_kernel_; } + + std::vector &GetOutKernels() { return this->out_kernel_; } + + void InitOutTensorRefCount(); + + int DecOutTensorRefCount(lite::Allocator *allocator = nullptr); + + const KernelKey Desc() const { return desc; } + + void set_desc(const KernelKey kernel_key) { desc = kernel_key; } + + protected: + KernelKey desc; + std::string name; + OpParameter *opParameter = nullptr; + // tensor will free in ~lite_session() + std::vector inputs_; + std::vector outputs_; + std::vector in_kernel_; + std::vector out_kernel_; +}; + +class SubGraphKernel : public LiteKernel { + public: + explicit SubGraphKernel(const std::vector &inputs, + const std::vector &outputs, + const std::vector &inKernels, + const std::vector &outKernels, + const std::vector &nodes) + : LiteKernel(nullptr, inputs, outputs), + inputs_(inputs), + outputs_(outputs), + inkernels_(inKernels), + outkernels_(outKernels), + nodes_(nodes) {} + + virtual int Init() { return -1; } + virtual int InferShape() { return -1; } + virtual int ReSize() { return -1; } + virtual int Run() { return -1; } + + protected: + std::vector inputs_; + std::vector outputs_; + std::vector inkernels_; + std::vector outkernels_; + std::vector nodes_; +}; + +typedef LiteKernel *(*KernelCreator)(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc); + +class LiteKernelUtil { + public: + static void TopologicalSortKernels(std::vector &kernels); + + static std::vector SubgraphInputKernels(const std::vector &kernels); + + static std::vector SubgraphOutputKernels(const std::vector &kernels); + + static std::vector SubgraphInputTensors(const std::vector &kernels); + + static std::vector SubgraphOutputTensors(const std::vector &kernels); + + static void InitTensorRefCount(std::vector &kernels); + + static int SetInput(LiteKernel &kernelMod, std::vector inputs); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_LITE_KERNEL_H_ + diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc new file mode 100644 index 00000000000..87cf187eeb7 --- /dev/null +++ b/mindspore/lite/src/lite_session.cc @@ -0,0 +1,283 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "include/errorcode.h" +#include "src/lite_session.h" +#include "utils/log_adapter.h" +#include "src/scheduler.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/allocator.h" +#include "src/executor.h" +#include "src/common/utils.h" +#include "src/common/graph_util.h" +#if SUPPORT_GPU +#include "src/runtime/opencl/opencl_runtime.h" +#endif + +namespace mindspore { +namespace lite { +int LiteSession::ConvertTensors(const lite::Model *model) { + MS_EXCEPTION_IF_NULL(model); + auto meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + uint32_t tensorCount = meta_graph->allTensors()->size(); + for (uint32_t i = 0; i < tensorCount; i++) { + auto *srcTensor = meta_graph->allTensors()->GetAs(i); + if (srcTensor == nullptr) { + MS_LOG(ERROR) << i << "th tensor in meta_graph is nullptr"; + return RET_NULL_PTR; + } + std::vector shape; + if (srcTensor->dims() == nullptr) { + MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; + } else { + if (srcTensor->nodeType() == schema::NodeType_ValueNode) { + for (size_t j = 0; j < srcTensor->dims()->size(); j++) { + shape.push_back(srcTensor->dims()->data()[j]); + } + } + } + int dataType = srcTensor->dataType(); + auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType()); + if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr && + srcTensor->data()->size() > 0) { + if (shape.empty()) { + shape.push_back(1); + } + MS_ASSERT(dstTensor != nullptr); + MS_ASSERT(dstTensor->Size() == srcTensor->data()->size()); + // no copy data, do copy when call LiteKernel::Init + dstTensor->SetData(const_cast(srcTensor->data()->data())); + } + this->tensors.emplace_back(dstTensor); + } + return RET_OK; +} + +int LiteSession::ConvertKernels(const lite::Model *model, Context *context) { + // MS_EXCEPTION_IF_NULL(model); + // auto meta_graph = model->GetMetaGraph(); + // MS_EXCEPTION_IF_NULL(meta_graph); + // uint32_t kernelCount = meta_graph->nodes()->size(); + // for (uint32_t i = 0; i < kernelCount; i++) { + // auto cNode = meta_graph->nodes()->GetAs(i); + // std::vector inputs; + // std::vector outputs; + // auto inIndexes = cNode->inputIndex(); + // for (size_t j = 0; j < inIndexes->size(); j++) { + // inputs.emplace_back(this->tensors.at(size_t(inIndexes->GetAs(j)))); + // } + // auto outIndexes = cNode->outputIndex(); + // for (size_t j = 0; j < outIndexes->size(); j++) { + // outputs.emplace_back(this->tensors.at(size_t(outIndexes->GetAs(j)))); + // } + // const auto *primitive = model->GetOp(cNode->name()->str()); + // if (primitive == nullptr) { + // MS_LOG(ERROR) << "Op " << cNode->name()->str() << " should exist in model"; + // return RET_ERROR; + // } + // auto ret = primitive->InferShape(inputs, outputs); + // if (0 != ret) { + // MS_LOG(ERROR) << "InferShape failed, node : " << cNode->name()->str(); + // return ret; + // } + // auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(inputs, outputs, cNode, context); + // if (nullptr == kernel) { + // MS_LOG(ERROR) << "Create kernel return nullptr, name: " << cNode->name()->str() + // << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + // return RET_ERROR; + // } + // kernels.emplace_back(kernel); + // } + return RET_OK; +} + +void LiteSession::InitGraphInOutTensor(const lite::Model *model) { + auto meta_graph = model->GetMetaGraph(); + MS_ASSERT(this->input_map.empty()); + MS_ASSERT(meta_graph != nullptr); + auto graph_input_node_indexes = GetGraphInputNodes(meta_graph); + for (auto in_node_index : graph_input_node_indexes) { + auto *in_node = meta_graph->nodes()->GetAs(in_node_index); + MS_ASSERT(nullptr != in_node); + MS_ASSERT(this->input_map.find(in_node->name()->str()) == this->input_map.end()); + for (size_t i = 0; i < in_node->inputIndex()->size(); i++) { + auto in_tensor_index = size_t(in_node->inputIndex()->GetAs(i)); + bool is_graph_input = false; + for (size_t j = 0; j < meta_graph->inputIndex()->size(); j++) { + if (in_tensor_index == size_t(meta_graph->inputIndex()->GetAs(j))) { + is_graph_input = true; + break; + } + } + if (!is_graph_input) { + continue; + } + MS_ASSERT(in_tensor_index < this->tensors.size()); + auto *in_tensor = this->tensors.at(in_tensor_index); + MS_ASSERT(in_tensor != nullptr); + auto *ms_tensor = new tensor::LiteTensor(in_tensor); + MS_ASSERT(nullptr != ms_tensor); + this->input_map[in_node->name()->str()].emplace_back(ms_tensor); + } + } + + auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph); + for (auto out_node_index : graph_output_node_indexes) { + auto *out_node = meta_graph->nodes()->GetAs(out_node_index); + MS_ASSERT(nullptr != out_node); + MS_ASSERT(this->output_map.find(out_node->name()->str()) == this->output_map.end()); + for (size_t i = 0; i < out_node->outputIndex()->size(); i++) { + auto out_tensor_index = size_t(out_node->outputIndex()->GetAs(i)); + bool is_graph_output = false; + for (size_t j = 0; j < meta_graph->outputIndex()->size(); j++) { + if (out_tensor_index == size_t(meta_graph->outputIndex()->GetAs(j))) { + is_graph_output = true; + break; + } + } + if (!is_graph_output) { + continue; + } + MS_ASSERT(out_tensor_index < this->tensors.size()); + auto *out_tensor = this->tensors.at(out_tensor_index); + MS_ASSERT(out_tensor != nullptr); + auto *ms_tensor = new tensor::LiteTensor(out_tensor); + MS_ASSERT(nullptr != ms_tensor); + this->output_map[out_node->name()->str()].emplace_back(ms_tensor); + } + } +} + +int LiteSession::CompileGraph(Model *model) { + // model.MetaGraph ==> kernels + if (model == nullptr) { + MS_LOG(ERROR) << "The input model is nullptr."; + return RET_PARAM_INVALID; + } + + auto ret = ConvertTensors(model); + if (0 != ret) { + MS_LOG(ERROR) << "ConvertTensors failed: " << ret; + return ret; + } + + InitGraphInOutTensor(model); + + // scheduler kernels + Scheduler scheduler(context); + ret = scheduler.Schedule(model, &tensors, &kernels); + if (0 != ret) { + MS_LOG(ERROR) << "Schedule kernels failed: " << ret; + return ret; + } + + return RET_OK; +} + +std::vector LiteSession::GetInputs() { + std::vector ret; + for (auto &iter : this->input_map) { + auto &node_input_tensors = iter.second; + for (auto tensor : node_input_tensors) { + if (!IsContain(ret, tensor)) { + ret.emplace_back(tensor); + } + } + } + return ret; +} + +int LiteSession::RunGraph() { + MS_EXCEPTION_IF_NULL(this->context); + Executor executor; + return executor.Run(this->inputs, this->outputs, this->kernels, this->context->allocator.get()); +} + +int LiteSession::RunGraph(const kernel::KernelCallBack &before, const kernel::KernelCallBack &after) { + MS_EXCEPTION_IF_NULL(this->context); + Executor executor; + return executor.Run(this->inputs, this->outputs, this->kernels, this->context->allocator.get(), before, after); +} + +std::vector LiteSession::GetOutputs() { + std::vector ret; + for (auto &iter : this->output_map) { + auto &node_output_tensors = iter.second; + for (auto tensor : node_output_tensors) { + if (!IsContain(ret, tensor)) { + ret.emplace_back(tensor); + } + } + } + return ret; +} + +void LiteSession::Init(Context *context) { + MS_EXCEPTION_IF_NULL(context); + this->context = new Context; + this->context->cpuBindMode = context->cpuBindMode; + this->context->threadNum = context->threadNum; + this->context->deviceCtx.type = context->deviceCtx.type; + this->context->allocator = std::make_shared(); + ConfigThreadPool(context->cpuBindMode, context->threadNum); + +#if SUPPORT_GPU + if (context->deviceCtx.type == DT_GPU) { + auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + opencl_runtime->Init(); + } +#endif +} + +void LiteSession::BindThread(bool ifBind) { + if (this->context->cpuBindMode != NO_BIND) { + DoAllThreadBind(ifBind, static_cast(this->context->cpuBindMode)); + } +} + +LiteSession::~LiteSession() { + for (auto *tensor : tensors) { + delete tensor; + } + for (auto *input : inputs) { + ((tensor::LiteTensor *)input)->SetTensorImpl(nullptr); + delete input; + } + for (auto *output : outputs) { + ((tensor::LiteTensor *)output)->SetTensorImpl(nullptr); + delete output; + } + for (auto *kernel : kernels) { + delete kernel; + } +} +std::vector LiteSession::GetInputsByName(std::string name) { + return input_map[name]; +} +std::vector LiteSession::GetOutputsByName(std::string name) { + return output_map[name]; +} +} // namespace lite + +session::LiteSession *session::LiteSession::CreateSession(lite::Context *context) { + auto session = new lite::LiteSession(); + session->Init(context); + return session; +} +} // namespace mindspore + diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h new file mode 100644 index 00000000000..83c44f58e9d --- /dev/null +++ b/mindspore/lite/src/lite_session.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_LITE_SESSION_H_ +#define MINDSPORE_LITE_SRC_LITE_SESSION_H_ + +#include +#include +#include +#include +#include "include/ms_tensor.h" +#include "include/lite_session.h" +#include "include/model.h" +#include "include/context.h" +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + +namespace mindspore { +namespace lite { +class LiteSession : public session::LiteSession { + public: + LiteSession() = default; + + ~LiteSession() override; + + void Init(Context *context); + + void BindThread(bool ifBind) override; + + int CompileGraph(Model *model) override; + + std::vector GetInputs() override; + + std::vector GetInputsByName(std::string name) override; + + int RunGraph() override; + + int RunGraph(const kernel::KernelCallBack &before = nullptr, const kernel::KernelCallBack &after = nullptr); + + std::vector GetOutputs() override; + + std::vector GetOutputsByName(std::string name) override; + + protected: + int ConvertTensors(const lite::Model *model); + int ConvertKernels(const lite::Model *model, Context *context); + void InitGraphInOutTensor(const lite::Model *model); + + protected: + Context *context = nullptr; + std::vector kernels; + std::vector tensors; + // graph input tensors + std::vector inputs; + // graph output tensors + std::vector outputs; + // graph input node name -- input tensors + std::unordered_map> input_map; + // graph output node name -- output tensors + std::unordered_map> output_map; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_LITE_SESSION_H_ + diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc new file mode 100644 index 00000000000..b5c5ee06a3e --- /dev/null +++ b/mindspore/lite/src/model.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef SUPPORT_TRAIN +#include "src/train/model_impl.h" +#else +#include "src/model_impl.h" +#endif +#include "include/model.h" +#include "utils/log_adapter.h" + +namespace mindspore::lite { + +std::shared_ptr Model::Import(const char *model_buf, size_t size) { + auto model = std::make_shared(); + model->modelImpl = ModelImpl::Import(model_buf, size); + return model; +} + +lite::Primitive *Model::GetOp(const std::string &name) const { + MS_EXCEPTION_IF_NULL(modelImpl); + return const_cast(modelImpl->GetOp(name)); +} + +void Model::FreeMetaGraph() { + MS_EXCEPTION_IF_NULL(modelImpl); + return modelImpl->FreeMetaGraph(); +} + +const schema::MetaGraph *Model::GetMetaGraph() const { + MS_EXCEPTION_IF_NULL(modelImpl); + return modelImpl->GetMetaGraph(); +} + +std::shared_ptr Model::GetModelImpl() { + MS_EXCEPTION_IF_NULL(modelImpl); + return this->modelImpl; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc new file mode 100644 index 00000000000..c9265e6c2d2 --- /dev/null +++ b/mindspore/lite/src/model_impl.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/model_impl.h" +#include "utils/log_adapter.h" + +namespace mindspore::lite { +std::shared_ptr ModelImpl::Import(const char *model_buf, size_t size) { + MS_EXCEPTION_IF_NULL(model_buf); + flatbuffers::Verifier verify((const uint8_t *)model_buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + auto *inner_model_buf = new (std::nothrow) char[size]; + if (inner_model_buf == nullptr) { + MS_LOG(ERROR) << "new model buf fail."; + return nullptr; + } + memcpy(inner_model_buf, model_buf, size); + auto model = std::make_shared(inner_model_buf, size); + auto ret = model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + return model; +} + +lite::Primitive *ModelImpl::GetOp(const std::string &name) const { + auto iter = ops.find(name); + if (iter == ops.end()) { + return nullptr; + } else { + return iter->second; + } +} + +ModelImpl::~ModelImpl() { + delete (this->model_buf_); + for (auto iter : ops) { + delete (iter.second); + } + ops.clear(); +} + +void ModelImpl::FreeMetaGraph() { + delete this->model_buf_; + model_buf_ = nullptr; +} + +const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } + +lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { + MS_EXCEPTION_IF_NULL(srcPrim); + auto op_type = srcPrim->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(srcPrim)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(srcPrim)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(srcPrim)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(srcPrim)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(srcPrim)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(srcPrim)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_CaffeBatchNorm: + return new lite::CaffeBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(srcPrim)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(srcPrim)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(srcPrim)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(srcPrim)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(srcPrim)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(srcPrim)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(srcPrim)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(srcPrim)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(srcPrim)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(srcPrim)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(srcPrim)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(srcPrim)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(srcPrim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(srcPrim)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(srcPrim)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(srcPrim)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(srcPrim)); + case schema::PrimitiveType_Slice: + return new lite::Slice(const_cast(srcPrim)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(srcPrim)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(srcPrim)); + case schema::PrimitiveType_Flatten: + return new lite::Flatten(const_cast(srcPrim)); + case schema::PrimitiveType_MatMul: + return new lite::MatMul(const_cast(srcPrim)); + default: + break; + } + return nullptr; +} + +int ModelImpl::BuildOps() { + if (this->meta_graph == nullptr) { + MS_LOG(ERROR) << "mete_graph is nullptr"; + return -1; + } + MS_EXCEPTION_IF_NULL(meta_graph->nodes()); + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + auto name = cNode->name()->str(); + auto srcPrim = cNode->primitive(); + + this->ops[name] = CopyPrimitive(srcPrim); + // flatbuffers::FlatBufferBuilder fbb(1024); + // schema::Conv2DBuilder conv2DBuilder(fbb); + // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); + // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); + // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); + // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); + // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); + // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); + // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); + // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); + // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); + // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); + // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); + // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); + // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); + // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); + // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); + // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); + // schema::PrimitiveBuilder primBuilder(fbb); + // primBuilder.add_value_type(srcPrim->value_type()); + // primBuilder.add_value(conv2DBuilder.Finish()); + // + // fbb.Finish(conv2DBuilder.Finish()); + // auto buf = fbb.GetBufferPointer(); + // auto conv2D = flatbuffers::GetRoot(buf); + // fbb.Clear(); + // + // return const_cast(opDef); + } + return 0; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/model_impl.h b/mindspore/lite/src/model_impl.h new file mode 100644 index 00000000000..14e0a1ccb96 --- /dev/null +++ b/mindspore/lite/src/model_impl.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_ +#define MINDSPORE_LITE_SRC_MODEL_IMPL_H_ + +#include +#include +#include +#include "schema/model_generated.h" +#include "src/ops/ops.h" + +namespace mindspore { +namespace lite { +class ModelImpl { + public: + static std::shared_ptr Import(const char *model_buf, size_t size); + ModelImpl() = default; + explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { + meta_graph = schema::GetMetaGraph(model_buf); + } + virtual ~ModelImpl(); + lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *GetMetaGraph() const; + void FreeMetaGraph(); + int BuildOps(); + + protected: + lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); + + protected: + const char *model_buf_; + size_t buf_size_; + const schema::MetaGraph *meta_graph = nullptr; + std::map ops; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_ + diff --git a/mindspore/lite/src/ops/CMakeLists.txt b/mindspore/lite/src/ops/CMakeLists.txt new file mode 100644 index 00000000000..c468336fcae --- /dev/null +++ b/mindspore/lite/src/ops/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +add_library(ops_mid_ OBJECT ${OPS_SRC}) \ No newline at end of file diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc new file mode 100644 index 00000000000..6b9252cd39f --- /dev/null +++ b/mindspore/lite/src/ops/addn.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int AddN::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() < kDoubleNum) { + MS_LOG(ERROR) << "input size is error"; + return RET_INPUT_TENSOR_ERROR; + } + for (int i = 1; i < inputs_.size(); ++i) { + if (inputs_.at(i)->shape() != inputs_.at(0)->shape()) { + MS_LOG(ERROR) << "AddN inputs shape is not equal!"; + return RET_INPUT_TENSOR_ERROR; + } + } + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc new file mode 100644 index 00000000000..af94e597e7e --- /dev/null +++ b/mindspore/lite/src/ops/argmax.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArgMax::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + } + auto argmax_prim = this->primitive->value_as_ArgMax(); + + std::vector output_shape(input->shape()); + auto input_shape_size = input->shape().size(); + int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis(); + if (axis >= input_shape_size || axis < 0) { + MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size; + return RET_PARAM_INVALID; + } + output_shape.erase(output_shape.begin() + axis); + + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc new file mode 100644 index 00000000000..2323af643f4 --- /dev/null +++ b/mindspore/lite/src/ops/argmin.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + } + auto argmin_prim = this->primitive->value_as_ArgMin(); + auto input_shape_size = input->shape().size(); + int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis(); + if (axis >= input_shape_size || axis < 0) { + MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size; + return RET_PARAM_INVALID; + } + std::vector output_shape(input->shape()); + output_shape.erase(output_shape.begin() + axis); + + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc new file mode 100644 index 00000000000..2a6ce1320e2 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Arithmetic::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "The number of input must be " << kDoubleNum; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "The number of output must be " << kSingleNum; + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_[0]; + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_[1]; + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto input_shape0 = input0->shape(); + auto input_shape1 = input1->shape(); + + in_shape0_.resize(5); + in_shape1_.resize(5); + out_shape_.resize(5); + + ndim_ = input_shape0.size(); + if (input_shape0.size() < input_shape1.size()) { + ndim_ = input_shape1.size(); + auto fill_dim_num = input_shape1.size() - input_shape0.size(); + int j = 0; + for (int i = 0; i < input_shape1.size(); i++) { + if (i < fill_dim_num) { + in_shape0_[i] = 1; + } else { + in_shape0_[i] = input_shape0[j++]; + } + in_shape1_[i] = input_shape1[i]; + } + } else if (input_shape0.size() > input_shape1.size()) { + ndim_ = input_shape0.size(); + auto fill_dim_num = input_shape0.size() - input_shape1.size(); + int j = 0; + for (int i = 0; i < input_shape0.size(); i++) { + if (i < fill_dim_num) { + in_shape1_[i] = 1; + } else { + in_shape1_[i] = input_shape1[j++]; + } + in_shape0_[i] = input_shape0[i]; + } + } else { + for (int i = 0; i < input_shape0.size(); i++) { + in_shape1_[i] = input_shape1[i]; + in_shape0_[i] = input_shape0[i]; + } + } + + std::vector output_shape; + for (size_t i = 0; i < ndim_; i++) { + if (in_shape0_[i] != in_shape1_[i]) { + if (in_shape0_[i] == 1) { + out_shape_[i] = in_shape1_[i]; + } else if (in_shape1_[i] == 1) { + out_shape_[i] = in_shape0_[i]; + } else { + MS_LOG(ERROR) << "shapes of input tensors can not be broadCasted"; + return -1; + } + broadcasting_ = true; + } else { + out_shape_[i] = in_shape0_[i]; + } + output_shape.push_back(out_shape_[i]); + } + + output->set_shape(output_shape); + output->set_data_type(input0->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc new file mode 100644 index 00000000000..567a190f6a6 --- /dev/null +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ArithmeticSelf::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc new file mode 100644 index 00000000000..a3ca0b2b49d --- /dev/null +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kBatchToSpaceOutputNum = 1; +constexpr int kBatchToSpaceInputNum = 1; +constexpr int kBlockShapeSize = 2; +constexpr int kCropsSize = 4; +} // namespace + +int BatchToSpace::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + auto prim = this->primitive->value_as_BatchToSpace(); + auto block_shape = prim->blockShape(); + if (block_shape->size() != kBlockShapeSize) { + MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize; + return RET_PARAM_INVALID; + } + auto crops = prim->crops(); + if (crops->size() != kCropsSize) { + MS_LOG(ERROR) << "Crops size should be " << kCropsSize; + return RET_PARAM_INVALID; + } + size_t mul_block_shape = 1; + + for (size_t i = 0; i < kBlockShapeSize; ++i) { + if (block_shape->Get(i) <= 0) { + MS_LOG(ERROR) << "Input block_shape should > 0!"; + return RET_PARAM_INVALID; + } + if (input_shape[kNHWC_n_index] % block_shape->Get(i)) { + MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " can not divide block_shape[" << i << "] " + << block_shape->Get(i); + return RET_PARAM_INVALID; + } + mul_block_shape *= block_shape->Get(i); + } + + if (input_shape[kNHWC_n_index] < mul_block_shape) { + MS_LOG(ERROR) << "Dimension n " << input_shape[kNHWC_n_index] << " < product of block shape!"; + return RET_PARAM_INVALID; + } + for (size_t i = 0; i < kCropsSize; ++i) { + if (crops->Get(i) < 0) { + MS_LOG(ERROR) << "Input crops should >= 0"; + return RET_PARAM_INVALID; + } + } + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index] / mul_block_shape; + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_shape->Get(0) - crops->Get(0) - crops->Get(1); + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_shape->Get(1) - crops->Get(2) - crops->Get(3); + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index]; + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc new file mode 100644 index 00000000000..225e34d6147 --- /dev/null +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kBroadcastToInputNum = 1; +constexpr int kBroadcastToOutputNum = 1; +} // namespace + +int BroadcastTo::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { + MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + std::vector dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(), + this->primitive->value_as_BroadcastTo()->dst_shape()->end()); + auto input_shape = input->shape(); + std::vector shape(dst_shape.size()); + int input_shape_index = input_shape.size() - 1; + if (input_shape.size() > dst_shape.size()) { + MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size " + << dst_shape.size() << "!"; + return RET_PARAM_INVALID; + } + + for (int i = dst_shape.size() - 1; i >= 0; --i) { + if (dst_shape[i] < 0) { + MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!"; + return RET_PARAM_INVALID; + } + if (input_shape_index >= 0) { + auto dim = input_shape[input_shape_index]; + if (dim != dst_shape[i] && dim != 1) { + MS_LOG(ERROR) << "Invalid broadcast shape!"; + return RET_PARAM_INVALID; + } + } + shape[i] = dst_shape[i]; + --input_shape_index; + } + outputs[0]->set_shape(shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc new file mode 100644 index 00000000000..796f80cbee0 --- /dev/null +++ b/mindspore/lite/src/ops/cast.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Cast::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "tensor number is error."; + return RET_INPUT_TENSOR_ERROR; + } + auto cast_prim = this->primitive->value_as_Cast(); + MS_ASSERT(cast_prim != nullptr); + if (input->data_type() != cast_prim->srcT()) { + MS_LOG(ERROR) << "input dataType is error"; + return RET_INPUT_TENSOR_ERROR; + } + if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) { + MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + return RET_INPUT_TENSOR_ERROR; + } + if (cast_prim->dstT() != kNumberTypeFloat || cast_prim->dstT() != kNumberTypeFloat32) { + MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT(); + return RET_INPUT_TENSOR_ERROR; + } + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc new file mode 100644 index 00000000000..2d966676d50 --- /dev/null +++ b/mindspore/lite/src/ops/concat.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kConcatOutputNum = 1; +} +int Concat::InferShape(std::vector inputs_, std::vector outputs_) { + if (this->primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr!"; + return RET_PARAM_INVALID; + } + auto input0 = inputs_.front(); + auto output = outputs_.front(); + if (outputs_.size() != kConcatOutputNum) { + MS_LOG(ERROR) << "output size is error"; + return RET_PARAM_INVALID; + } + auto concat_prim = this->primitive->value_as_Concat(); + MS_ASSERT(concat_prim != nullptr); + auto input0_shape = inputs_.at(0)->shape(); + int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis(); + if (axis < 0 || axis >= input0_shape.size()) { + MS_LOG(ERROR) << "Invalid axis: " << axis; + return RET_PARAM_INVALID; + } + + auto input0_shape_without_axis = input0_shape; + input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis); + auto input0_data_type = inputs_.at(0)->data_type(); + int output_axis_dim = input0_shape.at(axis); + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_.at(i)->data_type() != input0_data_type) { + MS_LOG(ERROR) << "All inputs should have the same data type!"; + return RET_PARAM_INVALID; + } + + auto shape_tmp = inputs_.at(i)->shape(); + if (shape_tmp.size() != input0_shape.size()) { + MS_LOG(ERROR) << "All inputs should have the same dim num!"; + return RET_PARAM_INVALID; + } + auto axis_tmp = shape_tmp[axis]; + shape_tmp.erase(shape_tmp.begin() + axis); + if (input0_shape_without_axis != shape_tmp) { + MS_LOG(ERROR) << "Inputs should have the same dim except axis!"; + return RET_PARAM_INVALID; + } + output_axis_dim += axis_tmp; + } + auto output_shape = input0_shape; + output_shape[axis] = output_axis_dim; + outputs_[0]->set_shape(output_shape); + output->set_data_type(input0->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/conv.cc b/mindspore/lite/src/ops/conv.cc new file mode 100644 index 00000000000..9597e851fc8 --- /dev/null +++ b/mindspore/lite/src/ops/conv.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { + MS_ASSERT(this->primitive != nullptr); + auto conv2DPrim = this->primitive->value_as_Conv2D(); + int kernel_w = conv2DPrim->kernelW(); + int kernel_h = conv2DPrim->kernelH(); + int stride_w = conv2DPrim->strideW(); + int stride_h = conv2DPrim->strideH(); + int dilate_w = conv2DPrim->dilateW(); + int dilate_h = conv2DPrim->dilateH(); + pad_l_ = conv2DPrim->padLeft(); + pad_u_ = conv2DPrim->padUp(); + pad_d_ = conv2DPrim->padDown(); + pad_r_ = conv2DPrim->padRight(); + + if (conv2DPrim->padMode() == schema::PadMode_SAME) { + *output_w = std::ceil(static_cast(input_w) / static_cast(stride_w)); + *output_h = std::ceil(static_cast(input_h) / static_cast(stride_h)); + auto pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h); + auto pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + *output_w = std::ceil((static_cast(input_w) + pad_l_ + pad_r_ - + (static_cast(kernel_w) - 1) * static_cast(dilate_w)) / static_cast(stride_w)); + *output_h = std::ceil((static_cast(input_h) + pad_u_ + pad_d_ - + (static_cast(kernel_h) - 1) * static_cast(dilate_h)) / static_cast(stride_h)); + } +} + +int Conv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != 2 && inputs_.size() != 3) { + MS_LOG(ERROR) << "Add should has two or three inputs"; + return RET_ERROR; + } + if (outputs_.size() != 1) { + MS_LOG(ERROR) << "Add should has one outputs"; + return RET_ERROR; + } + auto *input_tensor = inputs_.front(); + auto *weight_tensor = inputs_.at(1); + auto *out_tensor = outputs_.front(); + MS_ASSERT(input_tensor != nullptr); + MS_ASSERT(out_tensor != nullptr); + + auto in_shape = input_tensor->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + this->ConvInferShape(input_h, input_w, &output_h, &output_w); + + std::vector out_shape{input_tensor->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight_tensor->shape()[0]; + out_tensor->set_shape(out_shape); + out_tensor->SetFormat(input_tensor->GetFormat()); + out_tensor->set_data_type(input_tensor->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/convolution_depthwise.cc b/mindspore/lite/src/ops/convolution_depthwise.cc new file mode 100644 index 00000000000..c26bda7b6f2 --- /dev/null +++ b/mindspore/lite/src/ops/convolution_depthwise.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + MS_LOG(ERROR) << "inputs number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + auto conv_prim = this->primitive->value_as_DepthwiseConv2D(); + pad_l_ = conv_prim->padLeft(); + pad_u_ = conv_prim->padUp(); + pad_d_ = conv_prim->padDown(); + pad_r_ = conv_prim->padRight(); + if (conv_prim->padMode() == schema::PadMode_SAME) { + output_h = std::ceil(static_cast(input_h) / static_cast(conv_prim->strideH())); + output_w = std::ceil(static_cast(input_w) / static_cast(conv_prim->strideW())); + auto pad_h_all = + ((output_h - 1) * conv_prim->strideH() + (conv_prim->kernelH() - 1) * conv_prim->dilateH() + 1 - input_h); + auto pad_w_all = + ((output_w - 1) * conv_prim->strideW() + (conv_prim->kernelW() - 1) * conv_prim->dilateW() + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + output_h = + std::ceil((static_cast(input_h) + pad_u_ + pad_d_ - (static_cast(conv_prim->kernelH()) - 1) * + static_cast(conv_prim->dilateH())) / static_cast(conv_prim->strideH())); + output_w = + std::ceil((static_cast(input_w) + pad_l_ + pad_r_ - (static_cast(conv_prim->kernelW()) - 1) * + static_cast(conv_prim->dilateW())) / static_cast(conv_prim->strideW())); + } + std::vector out_shape{input->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel + + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc new file mode 100644 index 00000000000..b58b8a27f44 --- /dev/null +++ b/mindspore/lite/src/ops/crop.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kCropOutputNum = 1; +constexpr int kCropInputNum = 2; +} // namespace + +int Crop::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + outputs[0]->set_shape(inputs[1]->shape()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/deconvolution.cc b/mindspore/lite/src/ops/deconvolution.cc new file mode 100644 index 00000000000..12aa1dcab1f --- /dev/null +++ b/mindspore/lite/src/ops/deconvolution.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DeConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + int32_t input_h = input->Height(); + int32_t input_w = input->Width(); + + int32_t output_n = input->Batch(); + int32_t output_h = 0; + int32_t output_w = 0; + int32_t output_c = weight->Batch(); + + auto deconv = GetAttribute(); + int kernel_w = deconv->kernelW(); + int kernel_h = deconv->kernelH(); + int stride_w = deconv->strideW(); + int stride_h = deconv->strideH(); + int dilate_w = deconv->dilateW(); + int dilate_h = deconv->dilateH(); + pad_l_ = deconv->padLeft(); + pad_u_ = deconv->padUp(); + pad_d_ = deconv->padDown(); + pad_r_ = deconv->padRight(); + schema::PadMode pad_mode = deconv->padMode(); + + if (pad_mode == schema::PadMode_CAFFE) { + output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - pad_u_ - pad_d_; + output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - pad_l_ - pad_r_; + } else if (pad_mode == schema::PadMode_SAME) { + output_h = (input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - deconv->padUp() - deconv->padDown(); + output_w = (input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - deconv->padLeft() - deconv->padRight(); + } else if (pad_mode == schema::PadMode_VALID) { + output_h = input_h * stride_h; + output_w = input_w * stride_w; + } else { + MS_LOG(ERROR) << "unsupported pad mode for deconv"; + } + + std::vector out_shape = {output_n, output_h, output_w, output_c}; + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/deconvolution_depthwise.cc b/mindspore/lite/src/ops/deconvolution_depthwise.cc new file mode 100644 index 00000000000..01e25424824 --- /dev/null +++ b/mindspore/lite/src/ops/deconvolution_depthwise.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int DeconvDepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + MS_LOG(ERROR) << "inputs number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output number is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight = inputs_.at(1); + MS_ASSERT(weight != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int input_h = in_shape.at(1); + int input_w = in_shape.at(2); + int output_w = 0, output_h = 0; + + auto conv_prim = this->primitive->value_as_DeDepthwiseConv2D(); + pad_l_ = conv_prim->padLeft(); + pad_u_ = conv_prim->padUp(); + pad_d_ = conv_prim->padDown(); + pad_r_ = conv_prim->padRight(); + output_h = conv_prim->strideH() * (input_h - 1) * conv_prim->kernelH() - pad_u_ - pad_d_; + output_w = conv_prim->strideW() * (input_w - 1) * conv_prim->kernelW() - pad_l_ - pad_r_; + if ((output_h + conv_prim->padUp() + conv_prim->padDown() - conv_prim->kernelH()) % conv_prim->strideH() != 0) { + output_h += (output_h + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelH()) % conv_prim->strideH(); + } + if ((output_w + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelW()) % conv_prim->strideW() != 0) { + output_w += (output_w + conv_prim->padLeft() + conv_prim->padRight() - conv_prim->kernelW()) % conv_prim->strideW(); + } + std::vector out_shape{input->shape()}; + out_shape.at(1) = output_h; + out_shape.at(2) = output_w; + out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel + + output->set_shape(out_shape); + output->SetFormat(input->GetFormat()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc new file mode 100644 index 00000000000..025c1ad3603 --- /dev/null +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kDepthToSpaceOutputNum = 1; +constexpr int kDepthToSpaceInputNum = 1; +} + +int DepthToSpace::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kDimension_4d) { + MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + return RET_PARAM_INVALID; + } + auto prim = this->primitive->value_as_DepthToSpace(); + int32_t block_size = prim->blockSize(); + if (input_shape[kNHWC_c_index] % (block_size * block_size) != 0 || input_shape[kNHWC_c_index] == 0) { + MS_LOG(ERROR) << "input dimension c size " << input_shape[kNHWC_c_index] << " should be mulitple of block_size(" + << block_size << ") * block_size)!"; + return RET_PARAM_INVALID; + } + std::vector output_shape(input_shape.size()); + output_shape[kNHWC_n_index] = input_shape[kNHWC_n_index]; + output_shape[kNHWC_h_index] = input_shape[kNHWC_h_index] * block_size; + output_shape[kNHWC_w_index] = input_shape[kNHWC_w_index] * block_size; + output_shape[kNHWC_c_index] = input_shape[kNHWC_c_index] / (block_size * block_size); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc new file mode 100644 index 00000000000..588710f886c --- /dev/null +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ExpandDims::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size is invalid"; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output size is invalid"; + } + auto expand_dims_prim = this->primitive->value_as_ExpandDims(); + int dim = expand_dims_prim->dim(); + if (dim < 0) { + dim += input->shape().size() + 1; + } + if (dim > input->shape().size()) { + MS_LOG(ERROR) << "attribute dim out of range"; + return RET_INPUT_TENSOR_ERROR; + } + auto out_shape = input->shape(); + out_shape.insert(out_shape.begin() + dim, 1, 1); + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc new file mode 100644 index 00000000000..361a5e2b8df --- /dev/null +++ b/mindspore/lite/src/ops/fill.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Fill::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Fill input or output is null!"; + return RET_ERROR; + } + + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto fill_prim = this->primitive->value_as_Fill(); + if (fill_prim == nullptr) { + MS_LOG(ERROR) << "Fill primitive is null!"; + return RET_ERROR; + } + std::vector output_shape; + (void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc new file mode 100644 index 00000000000..c2264afcf93 --- /dev/null +++ b/mindspore/lite/src/ops/flatten.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Flatten::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Flatten input or output is null!"; + return RET_ERROR; + } + + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + + auto input_shape = input->shape(); + std::vector output_shape(2); + output_shape[0] = input_shape[0]; + output_shape[1] = 1; + for (int i = 1; i < input_shape.size(); i++) { + output_shape[1] *= input_shape[i]; + } + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/fullconnection.cc b/mindspore/lite/src/ops/fullconnection.cc new file mode 100644 index 00000000000..0b44faecfd8 --- /dev/null +++ b/mindspore/lite/src/ops/fullconnection.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int FullConnection::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input0 = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_.at(1); + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto fc_prim = this->primitive->value_as_FullConnection(); + if ((fc_prim->hasBias() && inputs_.size() != kMultiNum) || (!fc_prim->hasBias() && inputs_.size() != kDoubleNum)) { + MS_LOG(ERROR) << "Input tensors num error"; + return RET_INPUT_TENSOR_ERROR; + } + if (fc_prim->axis() < 1 || fc_prim->axis() > input0->shape().size()) { + MS_LOG(ERROR) << "FullConnection axis invalid"; + return RET_INPUT_TENSOR_ERROR; + } + int new_k = 1; + for (size_t i = fc_prim->axis(); i < input0->shape().size(); ++i) { + new_k *= input0->shape().at(i); + } + if (new_k != input1->shape().at(1)) { + MS_LOG(ERROR) << "Input1 size invalid"; + return RET_PARAM_INVALID; + } + if (fc_prim->hasBias()) { + if (inputs_.at(2)->shape()[0] != input1->shape()[0]) { + MS_LOG(ERROR) << "bias size invalid"; + return RET_PARAM_INVALID; + } + } + std::vector out_shape{inputs_[0]->shape()}; + out_shape.resize(fc_prim->axis() + 1); + out_shape[fc_prim->axis()] = input1->shape()[0]; + output->set_shape(out_shape); + output->set_data_type(input0->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc new file mode 100644 index 00000000000..0e5cd619186 --- /dev/null +++ b/mindspore/lite/src/ops/gather.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Gather::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "Gather should have two inputs"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "Gather should have one outputs"; + return RET_INPUT_TENSOR_ERROR; + } + + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto indices = inputs_.at(1); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(input != nullptr); + + auto gather_prim = this->primitive->value_as_Gather(); + MS_ASSERT(gather_prim != nullptr); + + int axis = gather_prim->axis(); + int batch_dims = gather_prim->batchDims(); + if (axis < 0) { + axis += input->shape().size(); + } + auto indices_shape = indices->shape(); + int indices_rank = indices_shape.size(); + if (indices_rank < batch_dims + 1) { + MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1"; + return RET_ERROR; + } + if (batch_dims != 0) { + MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support"; + return RET_ERROR; + } + auto in_shape = input->shape(); + int in_rank = in_shape.size(); + if (in_rank < axis + 1) { + MS_LOG(ERROR) << "input[0]'s rank is less than axis + 1"; + return RET_ERROR; + } + + std::vector out_shape{in_shape}; + out_shape.erase(out_shape.begin() + axis); + for (size_t i = 0; i < indices_rank; i++) { + out_shape.insert(out_shape.begin() + axis, indices_shape[i]); + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc new file mode 100644 index 00000000000..4f5598817bf --- /dev/null +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int GatherNd::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "GatherNd should have two inputs"; + return RET_INPUT_TENSOR_ERROR; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "GatherNd should have one outputs"; + return RET_INPUT_TENSOR_ERROR; + } + + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto indices = inputs_.at(1); + MS_ASSERT(indices != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + auto in_shape = input->shape(); + int in_rank = in_shape.size(); + auto indices_shape = indices->shape(); + int indices_rank = indices_shape.size(); + + if (indices_shape[indices_rank - 1] > in_rank) { + MS_LOG(ERROR) << "Input of indices data is error!"; + return RET_ERROR; + } + + std::vector out_shape; + int i = 0; + for (i = 0; i < indices_rank - 1; ++i) { + out_shape.emplace_back(indices_shape[i]); + } + for (i = indices_shape[indices_rank - 1]; i < in_rank; ++i) { + out_shape.emplace_back(in_shape[i]); + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc new file mode 100644 index 00000000000..d7cb772f41b --- /dev/null +++ b/mindspore/lite/src/ops/matmul.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int MatMul::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "OpMatMul inputs size: " << inputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto input1 = inputs_.at(1); + MS_ASSERT(input1 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + std::vector x_shape = input0->shape(); + std::vector w_shape = input1->shape(); + if (x_shape.size() < 2 || w_shape.size() < 2) { + MS_LOG(ERROR) << "inputs shape is invalid"; + return RET_INPUT_TENSOR_ERROR; + } + + auto matmul_prim = this->primitive->value_as_MatMul(); + if (matmul_prim->transposeA()) { + int tmp = x_shape.back(); + x_shape[x_shape.size() - 1] = x_shape[x_shape.size() - 2]; + x_shape[x_shape.size() - 2] = tmp; + } + if (matmul_prim->transposeB()) { + int tmp = w_shape.back(); + w_shape[w_shape.size() - 1] = w_shape[w_shape.size() - 2]; + w_shape[w_shape.size() - 2] = tmp; + } + auto y_shape_size = std::max(x_shape.size(), w_shape.size()); + std::vector y_shape(y_shape_size); + y_shape = x_shape; + y_shape[y_shape_size - 1] = w_shape[w_shape.size() - 1]; + output->set_shape(y_shape); + output->set_data_type(input0->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc new file mode 100644 index 00000000000..bd5f27b86a7 --- /dev/null +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/common/common.h" + +namespace mindspore::lite { +int Nchw2Nhwc::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector nchw_shape = input->shape(); + std::vector nhwc_shape{nchw_shape}; + nhwc_shape[NHWC_N] = nchw_shape[NCHW_N]; + nhwc_shape[NHWC_H] = nchw_shape[NCHW_H]; + nhwc_shape[NHWC_W] = nchw_shape[NCHW_W]; + nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; + output->set_shape(nhwc_shape); + output->SetFormat(schema::Format_NHWC); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc new file mode 100644 index 00000000000..049b1a4a187 --- /dev/null +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/common/common.h" + +namespace mindspore::lite { +int Nhwc2Nchw::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector nhwc_shape = input->shape(); + std::vector nchw_shape{nhwc_shape}; + nchw_shape[NCHW_N] = nhwc_shape[NHWC_N]; + nchw_shape[NCHW_C] = nhwc_shape[NHWC_C]; + nchw_shape[NCHW_H] = nhwc_shape[NHWC_H]; + nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; + output->set_shape(nchw_shape); + output->SetFormat(schema::Format_NCHW); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc new file mode 100644 index 00000000000..eb96edad891 --- /dev/null +++ b/mindspore/lite/src/ops/one_hot.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr size_t kOneHotInputNum = 4; +} +int OneHot::InferShape(std::vector inputs, std::vector outputs) { + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto one_hot_prim = this->primitive->value_as_OneHot(); + if (one_hot_prim == nullptr) { + return RET_NULL_PTR; + } + int axis = one_hot_prim->axis(); + + // indices, depth, on_value, off_value + if (inputs.size() != kOneHotInputNum) { + MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; + return RET_ERROR; + } + auto depth_tensor = inputs.at(1); + if (depth_tensor == nullptr) { + return RET_NULL_PTR; + } + const int *depth = static_cast(depth_tensor->Data()); + + auto input = inputs.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + const auto input_shape = input->shape(); + int input_rank = static_cast(input_shape.size()); + if (axis < 0) { + axis += input_rank + 1; + } + std::vector output_shape(input_shape); + output_shape.insert(output_shape.cbegin() + axis, *depth); + + auto output = outputs.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + output->set_shape(output_shape); + + auto on_value = inputs.at(2); + if (on_value == nullptr) { + return RET_NULL_PTR; + } + output->set_data_type(on_value->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.cc b/mindspore/lite/src/ops/ops.cc new file mode 100644 index 00000000000..642b2898fbb --- /dev/null +++ b/mindspore/lite/src/ops/ops.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { + MS_ASSERT(primitive != nullptr); + auto op_type = primitive->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(primitive)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(primitive)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(primitive)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(primitive)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(primitive)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(primitive)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(primitive)); + case schema::PrimitiveType_CaffeBatchNorm: + return new lite::CaffeBatchNorm(const_cast(primitive)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(primitive)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(primitive)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(primitive)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(primitive)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(primitive)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(primitive)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(primitive)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(primitive)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(primitive)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(primitive)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(primitive)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(primitive)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(primitive)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(primitive)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(primitive)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(primitive)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(primitive)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(primitive)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(primitive)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(primitive)); + case schema::PrimitiveType_Squeeze: + return new lite::Squeeze(const_cast(primitive)); + case schema::PrimitiveType_SquaredDifference: + return new lite::SquaredDifference(const_cast(primitive)); + case schema::PrimitiveType_Split: + return new lite::Split(const_cast(primitive)); + case schema::PrimitiveType_FloorDiv: + return new lite::FloorDiv(const_cast(primitive)); + case schema::PrimitiveType_FloorMod: + return new lite::FloorMod(const_cast(primitive)); + case schema::PrimitiveType_Reverse: + return new lite::Reverse(const_cast(primitive)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(primitive)); + case schema::PrimitiveType_GatherNd: + return new lite::GatherNd(const_cast(primitive)); + case schema::PrimitiveType_Tile: + return new lite::Tile(const_cast(primitive)); + case schema::PrimitiveType_TopK: + return new lite::TopK(const_cast(primitive)); + case schema::PrimitiveType_Unique: + return new lite::Unique(const_cast(primitive)); + case schema::PrimitiveType_Unstack: + return new lite::Unstack(const_cast(primitive)); + case schema::PrimitiveType_ReverseSequence: + return new lite::ReverseSequence(const_cast(primitive)); + case schema::PrimitiveType_Round: + return new lite::Round(const_cast(primitive)); + case schema::PrimitiveType_ZerosLike: + return new lite::ZerosLike(const_cast(primitive)); + case schema::PrimitiveType_Where: + return new lite::Where(const_cast(primitive)); + case schema::PrimitiveType_Floor: + return new lite::Floor(const_cast(primitive)); + case schema::PrimitiveType_Shape: + return new lite::Shape(const_cast(primitive)); + case schema::PrimitiveType_ScatterND: + return new lite::ScatterND(const_cast(primitive)); + case schema::PrimitiveType_Unsqueeze: + return new lite::Unsqueeze(const_cast(primitive)); + case schema::PrimitiveType_Flatten: + return new lite::Flatten(const_cast(primitive)); + case schema::PrimitiveType_StridedSlice: + return new lite::StridedSlice(const_cast(primitive)); + default: + break; + } + return nullptr; +} + +int Primitive::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h new file mode 100644 index 00000000000..def9c3ee724 --- /dev/null +++ b/mindspore/lite/src/ops/ops.h @@ -0,0 +1,666 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_OPS_H_ +#define MINDSPORE_LITE_SRC_OPS_OPS_H_ + +#include +#include +#include +#include "schema/model_generated.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite::tensor { +class Tensor; +} +namespace lite { +constexpr uint32_t kSingleNum = 1; +constexpr uint32_t kDoubleNum = 2; +constexpr uint32_t kMultiNum = 3; +constexpr uint32_t kNHWC_n_index = 0; +constexpr uint32_t kNHWC_h_index = 1; +constexpr uint32_t kNHWC_w_index = 2; +constexpr uint32_t kNHWC_c_index = 3; +constexpr uint32_t kDimension_4d = 4; + +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32}; + +class Primitive { + public: + explicit Primitive(schema::Primitive *primitive) : primitive(primitive) {} + static Primitive *CreatePrimitive(schema::Primitive *primitive); + virtual ~Primitive() {} + const schema::Primitive *Value() const { return this->primitive; } + schema::PrimitiveType Type() const { return this->primitive->value_type(); } + const void *Attribute() const { return this->primitive->value(); } + virtual int InferShape(std::vector inputs_, std::vector outputs_); + + protected: + schema::Primitive *primitive; +}; + +class Conv2D : public Primitive { + public: + explicit Conv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Conv2D *GetAttribute() const { return this->primitive->value_as_Conv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w); + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class Pooling : public Primitive { + public: + explicit Pooling(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Pooling *GetAttribute() const { return this->primitive->value_as_Pooling(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class CaffeBatchNorm : public Primitive { + public: + explicit CaffeBatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::CaffeBatchNorm *GetAttribute() const { return this->primitive->value_as_CaffeBatchNorm(); } +}; + +class FusedBatchNorm : public Primitive { + public: + explicit FusedBatchNorm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::FusedBatchNorm *GetAttribute() const { return this->primitive->value_as_FusedBatchNorm(); } +}; + +class Activation : public Primitive { + public: + explicit Activation(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Activation *GetAttribute() const { return this->primitive->value_as_Activation(); } +}; + +class Split : public Primitive { + public: + explicit Split(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Split *GetAttribute() const { return this->primitive->value_as_Split(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reshape : public Primitive { + public: + explicit Reshape(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reshape *GetAttribute() const { return this->primitive->value_as_Reshape(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + + private: + int CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_shape) const; +}; + +class FullConnection : public Primitive { + public: + explicit FullConnection(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::FullConnection *GetAttribute() const { return this->primitive->value_as_FullConnection(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class SoftMax : public Primitive { + public: + explicit SoftMax(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::SoftMax *GetAttribute() const { return this->primitive->value_as_SoftMax(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reduce : public Primitive { + public: + explicit Reduce(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reduce *GetAttribute() const { return this->primitive->value_as_Reduce(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class DepthwiseConv2D : public Primitive { + public: + explicit DepthwiseConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DepthwiseConv2D *GetAttribute() const { return this->primitive->value_as_DepthwiseConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class DeConv2D : public Primitive { + public: + explicit DeConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DeConv2D *GetAttribute() const { return this->primitive->value_as_DeConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class DeconvDepthwiseConv2D : public Primitive { + public: + explicit DeconvDepthwiseConv2D(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DeDepthwiseConv2D *GetAttribute() const { return this->primitive->value_as_DeDepthwiseConv2D(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; + int PadUp() const { return this->pad_u_; } + int PadDown() const { return this->pad_d_; } + int PadLeft() const { return this->pad_l_; } + int PadRight() const { return this->pad_r_; } + + protected: + int pad_u_ = 0; + int pad_d_ = 0; + int pad_l_ = 0; + int pad_r_ = 0; +}; + +class Power : public Primitive { + public: + explicit Power(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Power *GetAttribute() const { return this->primitive->value_as_Power(); } +}; + +class Range : public Primitive { + public: + explicit Range(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Range *GetAttribute() const { return this->primitive->value_as_Range(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class AddN : public Primitive { + public: + explicit AddN(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::AddN *GetAttribute() const { return this->primitive->value_as_AddN(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Arithmetic : public Primitive { + public: + explicit Arithmetic(schema::Primitive *primitive) : Primitive(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; + bool Broadcasting() { return this->broadcasting_; } + int NDims() { return this->ndim_; } + std::vector InShape0() { return this->in_shape0_; } + std::vector InShape1() { return this->in_shape1_; } + std::vector OutputShape() { return this->out_shape_; } + + protected: + bool broadcasting_ = false; + int ndim_; + std::vector in_shape0_; + std::vector in_shape1_; + std::vector out_shape_; +}; + +class Add : public Arithmetic { + public: + explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Add *GetAttribute() const { return this->primitive->value_as_Add(); } +}; + +class Mul : public Arithmetic { + public: + explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Mul *GetAttribute() const { return this->primitive->value_as_Mul(); } +}; + +class Sub : public Arithmetic { + public: + explicit Sub(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } +}; + +class Div : public Arithmetic { + public: + explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Div *GetAttribute() const { return this->primitive->value_as_Div(); } +}; + +class SquaredDifference : public Arithmetic { + public: + explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::SquaredDifference *GetAttribute() const { return this->primitive->value_as_SquaredDifference(); } +}; + +class Eltwise : public Arithmetic { + public: + explicit Eltwise(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Eltwise *GetAttribute() const { return this->primitive->value_as_Eltwise(); } +}; + +class ArithmeticSelf : public Primitive { + public: + explicit ArithmeticSelf(schema::Primitive *primitive) : Primitive(primitive) {} + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Abs : public ArithmeticSelf { + public: + explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Abs *GetAttribute() const { return this->primitive->value_as_Abs(); } +}; + +class Cos : public ArithmeticSelf { + public: + explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Cos *GetAttribute() const { return this->primitive->value_as_Cos(); } +}; + +class Exp : public ArithmeticSelf { + public: + explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Exp *GetAttribute() const { return this->primitive->value_as_Exp(); } +}; + +class Log : public ArithmeticSelf { + public: + explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Log *GetAttribute() const { return this->primitive->value_as_Log(); } +}; + +class Square : public ArithmeticSelf { + public: + explicit Square(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Square *GetAttribute() const { return this->primitive->value_as_Square(); } +}; + +class Sqrt : public ArithmeticSelf { + public: + explicit Sqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Sqrt *GetAttribute() const { return this->primitive->value_as_Sqrt(); } +}; + +class Rsqrt : public ArithmeticSelf { + public: + explicit Rsqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Rsqrt *GetAttribute() const { return this->primitive->value_as_Rsqrt(); } +}; + +class Sin : public ArithmeticSelf { + public: + explicit Sin(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Sin *GetAttribute() const { return this->primitive->value_as_Sin(); } +}; + +class LogicalNot : public ArithmeticSelf { + public: + explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::LogicalNot *GetAttribute() const { return this->primitive->value_as_LogicalNot(); } +}; + +class RealDiv : public Arithmetic { + public: + explicit RealDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::RealDiv *GetAttribute() const { return this->primitive->value_as_RealDiv(); } +}; + +class BiasAdd : public Primitive { + public: + explicit BiasAdd(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BiasAdd *GetAttribute() const { return this->primitive->value_as_BiasAdd(); } +}; + +class ExpandDims : public Primitive { + public: + explicit ExpandDims(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ExpandDims *GetAttribute() const { return this->primitive->value_as_ExpandDims(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Unsqueeze : public Primitive { + public: + explicit Unsqueeze(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unsqueeze *GetAttribute() const { return this->primitive->value_as_Unsqueeze(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Cast : public Primitive { + public: + explicit Cast(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Cast *GetAttribute() const { return this->primitive->value_as_Cast(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Ceil : public ArithmeticSelf { + public: + explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); } +}; + +class Concat : public Primitive { + public: + explicit Concat(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Concat *GetAttribute() const { return this->primitive->value_as_Concat(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Fill : public Primitive { + public: + explicit Fill(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Fill *GetAttribute() const { return this->primitive->value_as_Fill(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class ArgMax : public Primitive { + public: + explicit ArgMax(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ArgMax *GetAttribute() const { return this->primitive->value_as_ArgMax(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class ArgMin : public Primitive { + public: + explicit ArgMin(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ArgMin *GetAttribute() const { return this->primitive->value_as_ArgMin(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class MatMul : public Primitive { + public: + explicit MatMul(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::MatMul *GetAttribute() const { return this->primitive->value_as_MatMul(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Nchw2Nhwc : public Primitive { + public: + explicit Nchw2Nhwc(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Nchw2Nhwc *GetAttribute() const { return this->primitive->value_as_Nchw2Nhwc(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Nhwc2Nchw : public Primitive { + public: + explicit Nhwc2Nchw(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Nhwc2Nchw *GetAttribute() const { return this->primitive->value_as_Nhwc2Nchw(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Rank : public Primitive { + public: + explicit Rank(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Rank *GetAttribute() const { return this->primitive->value_as_Rank(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Pad : public Primitive { + public: + explicit Pad(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Pad *GetAttribute() const { return this->primitive->value_as_Pad(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Gather : public Primitive { + public: + explicit Gather(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Gather *GatherAttribute() const { return this->primitive->value_as_Gather(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class GatherNd : public Primitive { + public: + explicit GatherNd(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::GatherNd *GetAttribute() const { return this->primitive->value_as_GatherNd(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Slice : public Primitive { + public: + explicit Slice(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Slice *GetAttribute() const { return this->primitive->value_as_Slice(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class BroadcastTo : public Primitive { + public: + explicit BroadcastTo(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BroadcastTo *GetAttribute() const { return this->primitive->value_as_BroadcastTo(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Squeeze : public Primitive { + public: + explicit Squeeze(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Squeeze *SqueezeAttribute() const { return this->primitive->value_as_Squeeze(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Floor : public ArithmeticSelf { + public: + explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); } +}; + +class FloorDiv : public Arithmetic { + public: + explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } +}; + +class FloorMod : public Arithmetic { + public: + explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } +}; + +class Transpose : public Primitive { + public: + explicit Transpose(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Transpose *GetAttribute() const { return this->primitive->value_as_Transpose(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class LocalResponseNormalization : public Primitive { + public: + explicit LocalResponseNormalization(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::LocalResponseNormalization *GetAttribute() const { + return this->primitive->value_as_LocalResponseNormalization(); + } +}; + +class Tile : public Primitive { + public: + explicit Tile(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Tile *GetAttribute() const { return this->primitive->value_as_Tile(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Reverse : public Primitive { + public: + explicit Reverse(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Reverse *GetAttribute() const { return this->primitive->value_as_Reverse(); } +}; + +class TopK : public Primitive { + public: + explicit TopK(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::TopK *GetAttribute() const { return this->primitive->value_as_TopK(); } + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; + +class Scale : public Primitive { + public: + explicit Scale(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Scale *GetAttribute() const { return this->primitive->value_as_Scale(); } +}; + +class Stack : public Primitive { + public: + explicit Stack(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Stack *GetAttribute() const { return this->primitive->value_as_Stack(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Unstack : public Primitive { + public: + explicit Unstack(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unstack *GetAttribute() const { return this->primitive->value_as_Unstack(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Unique : public Primitive { + public: + explicit Unique(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Unique *GetAttribute() const { return this->primitive->value_as_Unique(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class ReverseSequence : public Primitive { + public: + explicit ReverseSequence(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ReverseSequence *GetAttribute() const { return this->primitive->value_as_ReverseSequence(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class DepthToSpace : public Primitive { + public: + explicit DepthToSpace(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::DepthToSpace *GetAttribute() const { return this->primitive->value_as_DepthToSpace(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Resize : public Primitive { + public: + explicit Resize(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Resize *GetAttrbute() const { return this->primitive->value_as_Resize(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Round : public ArithmeticSelf { + public: + explicit Round(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Round *GetAttribute() const { return this->primitive->value_as_Round(); } +}; + +class ZerosLike : public Primitive { + public: + explicit ZerosLike(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ZerosLike *GetAttribute() const { return this->primitive->value_as_ZerosLike(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Where : public Primitive { + public: + explicit Where(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Where *GetAttribute() const { return this->primitive->value_as_Where(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class BatchToSpace : public Primitive { + public: + explicit BatchToSpace(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::BatchToSpace *GetAttribute() const { return this->primitive->value_as_BatchToSpace(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Crop : public Primitive { + public: + explicit Crop(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Crop *GetAttribute() const { return this->primitive->value_as_Crop(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Shape : public Primitive { + public: + explicit Shape(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Shape *GetAttribute() const { return this->primitive->value_as_Shape(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class ScatterND : public Primitive { + public: + explicit ScatterND(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::ScatterND *GetAttribute() const { return this->primitive->value_as_ScatterND(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class Flatten : public Primitive { + public: + explicit Flatten(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Flatten *GetAttribute() const { return this->primitive->value_as_Flatten(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class OneHot : public Primitive { + public: + explicit OneHot(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::OneHot *GetAttribute() const { return this->primitive->value_as_OneHot(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; + +class StridedSlice : public Primitive { + public: + explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); } + int InferShape(std::vector inputs, std::vector outputs) override; + int NDims() { return this->updated_ndim_; } + void ApplyNewAxisMask(); + std::vector ApplyShrinkMask(std::vector out_shape); + void ApplyBeginMask(); + void ApplyEndMask(); + void ApplyEllipsisMask(); + std::vector UpdatedInShape() { return this->updated_in_shape_; } + std::vector UpdatedBegins() { return this->updated_begins_; } + std::vector UpdatedEnds() { return this->updated_ends_; } + std::vector UpdatedStrides() { return this->updated_strides_; } + + protected: + int updated_ndim_; + int ori_ndim_; + std::vector updated_in_shape_; + std::vector updated_begins_; + std::vector updated_ends_; + std::vector updated_strides_; + std::vector begins_mask_; + std::vector ends_mask_; + std::vector ellipsis_mask_; + std::vector new_axis_mask_; + std::vector shrink_axis_mask_; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_OPS_OPS_H_ + diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc new file mode 100644 index 00000000000..8604da24e3f --- /dev/null +++ b/mindspore/lite/src/ops/pad.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +const size_t kPaddingsSize = 8; +const size_t kInputRank = 4; +} // namespace +int Pad::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto pad_prim = this->primitive->value_as_Pad(); + if (pad_prim == nullptr) { + return RET_NULL_PTR; + } + auto paddings = pad_prim->paddings(); + if (paddings == nullptr) { + return RET_NULL_PTR; + } + MS_ASSERT(paddings->size() == kPaddingsSize); + + auto input = inputs.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + auto input_shape = input->shape(); + MS_ASSERT(input_shape.size() == kInputRank); + std::vector output_shape; + for (size_t i = 0; i < input_shape.size(); i++) { + auto shape = input_shape[i] + (*paddings)[2 * i] + (*paddings)[2 * i + 1]; + output_shape.push_back(shape); + } + + auto output = outputs.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc new file mode 100644 index 00000000000..c25a558bbce --- /dev/null +++ b/mindspore/lite/src/ops/pooling.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Pooling::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + int input_h = input->shape().at(1); + int input_w = input->shape().at(2); + + auto pooling_prim = this->primitive->value_as_Pooling(); + MS_ASSERT(pooling_prim != nullptr); + auto window_h = pooling_prim->windowH(); + auto window_w = pooling_prim->windowW(); + if (pooling_prim->global()) { + window_h = input_h; + window_w = input_w; + } + + int output_h = 0; + int output_w = 0; + pad_l_ = pooling_prim->padLeft(); + pad_u_ = pooling_prim->padUp(); + pad_d_ = pooling_prim->padDown(); + pad_r_ = pooling_prim->padRight(); + if (pooling_prim->padMode() == schema::PadMode_SAME) { + output_w = std::ceil(static_cast(input_w) / static_cast(pooling_prim->strideW())); + output_h = std::ceil(static_cast(input_h) / static_cast(pooling_prim->strideH())); + auto pad_h_all = ((output_h - 1) * pooling_prim->strideH() + (window_h - 1) + 1 - input_h); + auto pad_w_all = ((output_w - 1) * pooling_prim->strideW() + (window_w - 1) + 1 - input_w); + pad_u_ = pad_h_all / 2; + pad_d_ = pad_h_all - pad_u_; + pad_l_ = pad_w_all / 2; + pad_r_ = pad_w_all - pad_l_; + } else { + auto round_mode = pooling_prim->roundMode(); + if (round_mode == schema::RoundMode_FLOOR) { + output_h = std::floor((input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH() + 1); + output_w = std::floor((input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW() + 1); + } else if (round_mode == schema::RoundMode_CEIL) { + output_h = + std::ceil((input_h + pooling_prim->padUp() + pooling_prim->padDown() - window_h) / pooling_prim->strideH() + 1); + output_w = std::ceil( + (input_w + pooling_prim->padLeft() + pooling_prim->padRight() - window_w) / pooling_prim->strideW() + 1); + } else { + MS_LOG(ERROR) << "unsupported round mode."; + } + } + + // todo: fmk type + auto input_shape = input->shape(); + input_shape.at(1) = output_h; + input_shape.at(2) = output_w; + output->set_shape(input_shape); + output->set_data_type(input->data_type()); + // todo: temp fix + output->SetFormat(schema::Format_NHWC); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc new file mode 100644 index 00000000000..4adafa26898 --- /dev/null +++ b/mindspore/lite/src/ops/range.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Range::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto range_prim = this->primitive->value_as_Range(); + MS_ASSERT(range_prim != nullptr); + + int shape_size = std::ceil(static_cast(range_prim->limit() - range_prim->start()) / range_prim->delta()); + std::vector in_shape(1); + in_shape.push_back(shape_size); + output->set_shape(in_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc new file mode 100644 index 00000000000..c7b70930b14 --- /dev/null +++ b/mindspore/lite/src/ops/rank.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Rank::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + std::vector in_shape(1, 1); + output->set_shape(in_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc new file mode 100644 index 00000000000..888a61df875 --- /dev/null +++ b/mindspore/lite/src/ops/reduce.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr size_t kInputSize = 1; +constexpr size_t kOutputSize = 1; +} // namespace +int Reduce::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kInputSize || outputs_.size() != kOutputSize) { + return RET_ERROR; + } + auto input = inputs_.front(); + auto output = outputs_.front(); + if (input == nullptr || output == nullptr) { + return RET_NULL_PTR; + } + if (this->primitive == nullptr) { + return RET_NULL_PTR; + } + auto reduce_prim = this->primitive->value_as_Reduce(); + bool keep_dims = static_cast(reduce_prim->keepDims()); + std::vector in_shape = input->shape(); + std::vector out_shape; + const auto &axes = reduce_prim->axes(); + auto num_axes = axes->size(); + // reduce on all axes + if (num_axes == 0) { + if (keep_dims) { + for (auto i = 0; i < in_shape.size(); i++) { + out_shape.push_back(1); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; + } + + // reduce on selected axes + for (size_t i = 0; i < in_shape.size(); i++) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (static_cast((*axes)[idx]) == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape.push_back(1); + } + } else { + out_shape.push_back(in_shape[i]); + } + } + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc new file mode 100644 index 00000000000..3e2a9c6eef8 --- /dev/null +++ b/mindspore/lite/src/ops/reshape.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_shape) const { + size_t in_shape_size = 1; + for (size_t i = 0; i < in_tensor->shape().size(); i++) { + in_shape_size *= in_tensor->shape()[i]; + } + + int64_t inferIndex = -1; + size_t out_shapeSize = 1; + for (size_t i = 0; i < out_shape->size(); i++) { + if (out_shape->at(i) == -1) { + if (inferIndex == -1) { + inferIndex = i; + } else { + MS_LOG(ERROR) << "output shape should has no more than one dim which need infer"; + return RET_ERROR; + } + } else if (out_shape->at(i) < 0) { + MS_LOG(ERROR) << "output shape dim should be non-negative"; + return RET_ERROR; + } else if (out_shape->at(i) == 0) { + out_shape->at(i) = in_tensor->shape().at(i); + out_shapeSize *= out_shape->at(i); + } else { + out_shapeSize *= out_shape->at(i); + } + } + + if (inferIndex == -1 && out_shapeSize != in_shape_size) { + MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size; + return RET_ERROR; + } + if (inferIndex != -1) { + out_shape->at(inferIndex) = in_shape_size / out_shapeSize; + } + return RET_OK; +} + +int Reshape::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto reshape_prim = this->primitive->value_as_Reshape(); + MS_ASSERT(reshape_prim != nullptr); + + std::vector out_shape; + if (inputs_.size() == kDoubleNum) { + auto shape_tensor = inputs_.at(1); + size_t shape_size = shape_tensor->shape().size(); + switch (shape_tensor->data_type()) { + case kNumberTypeInt8: { + auto data = reinterpret_cast(shape_tensor->Data()); + for (size_t i = 0; i < shape_size; i++) { + out_shape.push_back(data[i]); + } + } break; + case kNumberTypeInt32: { + auto data = reinterpret_cast(shape_tensor->Data()); + for (size_t i = 0; i < shape_size; i++) { + out_shape.push_back(data[i]); + } + } break; + case kNumberTypeFloat: { + auto data = reinterpret_cast(shape_tensor->Data()); + for (size_t i = 0; i < shape_size; i++) { + out_shape.push_back(data[i]); + } + } break; + case kNumberTypeUInt32: { + auto data = reinterpret_cast(shape_tensor->Data()); + for (size_t i = 0; i < shape_size; i++) { + out_shape.push_back(data[i]); + } + } break; + default: { + MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); + return RET_ERROR; + } + } + } else if (inputs_.size() == kSingleNum) { + std::copy(reshape_prim->shape()->begin(), reshape_prim->shape()->end(), std::back_inserter(out_shape)); + } else { + MS_LOG(ERROR) << "inputs tensor size invalid."; + } + + auto ret = CalNewShape(inputs_.front(), &out_shape); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CalNewShape error"; + return ret; + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc new file mode 100644 index 00000000000..7dd387c6369 --- /dev/null +++ b/mindspore/lite/src/ops/resize.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +namespace mindspore::lite { +namespace { +constexpr int kInputRank = 4; +} // namespace +int Resize::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + if (input == nullptr) { + return RET_NULL_PTR; + } + auto output = outputs_.front(); + if (output == nullptr) { + return RET_NULL_PTR; + } + auto resize = GetAttrbute(); + auto new_height = resize->newHeight(); + auto new_width = resize->newWidth(); + + std::vector output_shape; + output_shape.push_back(input->Batch()); + output_shape.push_back(new_height); + output_shape.push_back(new_width); + output_shape.push_back(input->Channel()); + output->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc new file mode 100644 index 00000000000..e7ff1e7e625 --- /dev/null +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ReverseSequence::InferShape(std::vector inputs, std::vector outputs) { + auto input = inputs.front(); + auto output = outputs.front(); + MS_ASSERT(input != nullptr); + MS_ASSERT(output != nullptr); + + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc new file mode 100644 index 00000000000..cf9f4dfbc35 --- /dev/null +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kScatterNDInputNum = 3; +constexpr int kScatterNDOutputNum = 1; +constexpr int kScatterShapeIndex = 0; +constexpr int kScatterIndicesIndex = 1; +constexpr int kScatterUpdateIndex = 2; +} // namespace + +int ScatterND::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kScatterNDInputNum) { + MS_LOG(ERROR) << "inputs number is not equal to " << kScatterNDInputNum; + return RET_ERROR; + } + if (outputs_.size() != kScatterNDOutputNum) { + MS_LOG(ERROR) << "outputs number is not equal to " << kScatterNDInputNum; + return RET_ERROR; + } + auto shape = inputs_.at(kScatterShapeIndex); + if (shape == nullptr) { + MS_LOG(ERROR) << "shape null pointer dereferencing."; + return RET_ERROR; + } + auto indices = inputs_.at(kScatterIndicesIndex); + if (indices == nullptr) { + MS_LOG(ERROR) << "indices null pointer dereferencing."; + return RET_ERROR; + } + auto update = inputs_.at(kScatterUpdateIndex); + if (update == nullptr) { + MS_LOG(ERROR) << "update null pointer dereferencing."; + return RET_ERROR; + } + auto output = outputs_.front(); + auto shape_data = reinterpret_cast(shape->Data()); + std::vector out_shape(shape_data, shape_data + sizeof(shape_data) / sizeof(shape_data[0])); + output->set_shape(out_shape); + output->set_data_type(update->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/shape.cc b/mindspore/lite/src/ops/shape.cc new file mode 100644 index 00000000000..8ca3416d78b --- /dev/null +++ b/mindspore/lite/src/ops/shape.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kShapeInputNum = 1; +constexpr int kShapeOutputNum = 1; + +} // namespace +int Shape::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kShapeInputNum) { + MS_LOG(ERROR) << "inputs to Shape operator should be 1, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + if (outputs_.size() != kShapeOutputNum) { + MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given."; + return RET_ERROR; + } + + auto in_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + std::vector out_shape; + out_shape.push_back(static_cast(in_tensor->shape().size())); + + auto ret_shape = out_tensor->set_shape(out_shape); + if (ret_shape != 1 || size_t(out_tensor->shape()[0]) != in_tensor->shape().size()) { + MS_LOG(ERROR) << "Set shape fails."; + return RET_ERROR; + } + auto ret_dtype = out_tensor->set_data_type({in_tensor->data_type()}); + if (ret_dtype != in_tensor->data_type()) { + MS_LOG(ERROR) << "Set datatype fails."; + return RET_ERROR; + } + + // todo + // auto ret_data = out_tensor->MallocData(); + // if (ret_data != 0) { + // MS_LOG(ERROR) << "Allocate memory fails."; + // return RET_ERROR; + // } + + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc new file mode 100644 index 00000000000..994fabb1cc5 --- /dev/null +++ b/mindspore/lite/src/ops/slice.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSliceInputNum = 1; +constexpr int kSliceOutputNum = 1; +} + +int Slice::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { + MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + auto input_shape = input->shape(); + auto slice_prim = this->primitive->value_as_Slice(); + std::vector slice_begin(slice_prim->begin()->begin(), slice_prim->begin()->end()); + std::vector slice_size(slice_prim->size()->begin(), slice_prim->size()->end()); + std::vector output_shape(input_shape.size()); + for (int i = 0; i < input_shape.size(); ++i) { + if (slice_size[i] < 0 && slice_size[i] != -1) { + MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i]; + return RET_PARAM_INVALID; + } + if (slice_begin[i] < 0) { + MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0"; + return RET_PARAM_INVALID; + } + if (input_shape[i] <= slice_begin[i]) { + MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i] << " which should be <= " + << input_shape[i]; + return RET_PARAM_INVALID; + } + if (slice_size[i] > (input_shape[i] - slice_begin[i])) { + MS_LOG(ERROR) << "Invalid size input " << slice_size[i] << " which should be <= " + << input_shape[i] - slice_begin[i]; + return RET_PARAM_INVALID; + } + + output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i]; + } + + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc new file mode 100644 index 00000000000..a0c46ce67d5 --- /dev/null +++ b/mindspore/lite/src/ops/softmax.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int SoftMax::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc new file mode 100644 index 00000000000..a7104b01d9f --- /dev/null +++ b/mindspore/lite/src/ops/split.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSplitInputNum = 1; +} // namespace +int Split::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto spilt_prim = this->primitive->value_as_Split(); + MS_ASSERT(spilt_prim != nullptr); + if (inputs_.size() != kSplitInputNum) { + MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum; + return RET_ERROR; + } + auto output = outputs_.front(); + if (output == nullptr) { + MS_LOG(ERROR) << "output null pointer dereferencing."; + return RET_ERROR; + } + int number_split = spilt_prim->numberSplit(); + if (outputs_.size() != number_split) { + MS_LOG(ERROR) << "outputs number is not equal to " << number_split; + return RET_ERROR; + } + int split_dim = spilt_prim->splitDim(); + std::vector input_shape = input->shape(); + std::vector size_split; + size_split.insert(size_split.begin(), spilt_prim->sizeSplits()->begin(), spilt_prim->sizeSplits()->end()); + + for (int i = 0; i < number_split; ++i) { + std::vector output_shape; + output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); + auto split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; + output_shape[split_dim] = split_dim_i; + outputs_[i]->set_shape(output_shape); + outputs_[i]->set_data_type(input->data_type()); + } + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc new file mode 100644 index 00000000000..d2c0b2841f0 --- /dev/null +++ b/mindspore/lite/src/ops/squeeze.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kSqueezeInputNum = 1; +constexpr int kSqueezeOutputNum = 1; +} +int Squeeze::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (kSqueezeInputNum != inputs_.size()) { + MS_LOG(ERROR) << "Add should has " << kSqueezeInputNum << " inputs"; + return -1; + } + if (kSqueezeOutputNum != outputs_.size()) { + MS_LOG(ERROR) << "Add should has " << kSqueezeOutputNum << " outputs"; + return -1; + } + auto *in_tensor = inputs_.front(); + auto in_shape = in_tensor->shape(); + std::vector out_shape; + + // todo: getAxis + auto squeeze_prim = this->primitive->value_as_Squeeze(); + MS_EXCEPTION_IF_NULL(squeeze_prim); + auto axis = squeeze_prim->axis(); + std::vector axes_; + for (auto iter = axis->begin(); iter != axis->end(); iter++) { + axes_.push_back(*iter); +} + + if (axes_.size() == 0) { + for (int i = 0; i < in_shape.size(); i++) { + if (in_shape[i] != 1) { + out_shape.push_back(in_shape[i]); + } + } + } else { + int axisIdx = 0; + for (int i = 0; i < in_shape.size(); i++) { + if (axisIdx < axes_.size() && axes_[axisIdx] == i) { + MS_ASSERT(in_shape[i] == 1); + axisIdx++; + continue; + } else { + out_shape.push_back(in_shape[i]); + } + } + } + + outputs_.front()->set_shape(out_shape); + outputs_.front()->set_data_type(in_tensor->data_type()); + + return 0; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc new file mode 100644 index 00000000000..e0bffc2a9ba --- /dev/null +++ b/mindspore/lite/src/ops/stack.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kStackOutputNum = 1; +constexpr int kStackMinInputNum = 2; +} + +int Stack::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kStackOutputNum) { + MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + if (inputs.size() < kStackMinInputNum) { + MS_LOG(ERROR) << "Invalid input size " << inputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + auto input_shape = input->shape(); + auto stack_prim = this->primitive->value_as_Stack(); + std::vector output_shape = input_shape; + int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis(); + if (axis < 0 || axis > input_shape.size()) { + MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis(); + return RET_PARAM_INVALID; + } + for (size_t i = 1; i < inputs.size(); ++i) { + auto input_shape_tmp = inputs[i]->shape(); + if (input_shape_tmp.size() != input_shape.size()) { + MS_LOG(ERROR) << "All input shape size should be the same!"; + return RET_PARAM_INVALID; + } + for (size_t j = 0; j < input_shape.size(); ++j) { + if (input_shape_tmp[j] != input_shape[j]) { + MS_LOG(ERROR) << "All input shape should be the same!"; + return RET_PARAM_INVALID; + } + } + } + + output_shape.insert(output_shape.begin() + axis, inputs.size()); + outputs[0]->set_shape(output_shape); + outputs[0]->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc new file mode 100644 index 00000000000..0fcd008ab8b --- /dev/null +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "include/errorcode.h" +#include "src/ops/ops.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +namespace { +constexpr int kStridedSliceOutputNum = 1; +constexpr int kStridedSliceInputNum = 1; +} // namespace + +void StridedSlice::ApplyNewAxisMask() { + for (int i = 0; i < new_axis_mask_.size(); i++) { + if (new_axis_mask_.at(i)) { + updated_ndim_ += 1; + updated_in_shape_.insert(updated_in_shape_.begin() + i, 1); + updated_begins_.at(i) = 0; + updated_ends_.at(i) = 1; + updated_strides_.at(i) = 1; + updated_begins_.emplace_back(0); + updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1)); + updated_strides_.emplace_back(1); + } + } +} + +std::vector StridedSlice::ApplyShrinkMask(std::vector out_shape) { + auto old_out_shape = out_shape; + out_shape.clear(); + for (int i = 0; i < shrink_axis_mask_.size(); i++) { + if (shrink_axis_mask_.at(i)) { + updated_ends_.at(i) = updated_begins_.at(i) + 1; + updated_strides_.at(i) = 1; + } else { + out_shape.emplace_back(old_out_shape.at(i)); + } + } + for (int i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) { + out_shape.emplace_back(old_out_shape.at(i)); + } + return out_shape; +} + +/*only one bit will be used if multiple bits are true.*/ +void StridedSlice::ApplyEllipsisMask() { + for (int i = 0; i < ellipsis_mask_.size(); i++) { + if (ellipsis_mask_.at(i)) { + updated_begins_.at(i) = 0; + updated_ends_.at(i) = updated_in_shape_.at(i); + break; + } + } +} + +void StridedSlice::ApplyBeginMask() { + for (int i = 0; i < ori_ndim_; i++) { + updated_begins_.at(i) = 0; + } +} + +void StridedSlice::ApplyEndMask() { + for (int i = 0; i < ori_ndim_; i++) { + updated_ends_.at(i) = 0; + } +} + +int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive != nullptr); + if (outputs.size() != kStridedSliceOutputNum) { + MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); + return RET_PARAM_INVALID; + } + if (inputs.size() < kStridedSliceInputNum) { + MS_LOG(ERROR) << "Invalid input size " << inputs.size(); + return RET_PARAM_INVALID; + } + auto input = inputs.at(0); + MS_ASSERT(input != nullptr); + auto input_shape = input->shape(); + std::vector output_shape; + auto strided_slice_prim = this->primitive->value_as_StridedSlice(); + updated_ndim_ = static_cast(strided_slice_prim->begin()->size()); + ori_ndim_ = updated_ndim_; + MS_ASSERT(updated_ndim_ == static_cast(strided_slice_prim->end()->size())); + MS_ASSERT(updated_ndim_ == static_cast(strided_slice_prim->stride()->size())); + MS_ASSERT(updated_ndim_ == static_cast(input_shape.size())); + + for (int i = 0; i < updated_ndim_; i++) { + updated_in_shape_.emplace_back(input_shape.at(i)); + updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]); + updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]); + updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]); + } + + // set all mask to original input shape + begins_mask_.resize(updated_ndim_); + ends_mask_.resize(updated_ndim_); + ellipsis_mask_.resize(updated_ndim_); + new_axis_mask_.resize(updated_ndim_); + shrink_axis_mask_.resize(updated_ndim_); + + // convert bit to vector + for (int i = 0; i < updated_ndim_; i++) { + begins_mask_.at(i) = static_cast(strided_slice_prim->beginMask()) & (1 << i); + ends_mask_.at(i) = static_cast(strided_slice_prim->endMask()) & (1 << i); + ellipsis_mask_.at(i) = static_cast(strided_slice_prim->ellipsisMask()) & (1 << i); + new_axis_mask_.at(i) = static_cast(strided_slice_prim->newAxisMask()) & (1 << i); + shrink_axis_mask_.at(i) = static_cast(strided_slice_prim->shrinkAxisMask()) & (1 << i); + } + + ApplyNewAxisMask(); + ApplyNewAxisMask(); + ApplyEndMask(); + ApplyEllipsisMask(); + + output_shape.resize(updated_in_shape_.size()); + for (int i = 0; i < updated_in_shape_.size(); i++) { + if (i < ori_ndim_ && new_axis_mask_.at(i)) { + output_shape.at(i) = 1; + } else { + // begins and ends out of range handling + if (updated_begins_.at(i) >= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) || + updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) { + return RET_PARAM_INVALID; + } + updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i); + updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i); + + if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) || + (updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) { + output_shape.at(i) = 0; + } else { + output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i); + } + } + } + + output_shape = ApplyShrinkMask(output_shape); + + outputs.front()->set_shape(output_shape); + outputs.front()->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc new file mode 100644 index 00000000000..875d68ba50e --- /dev/null +++ b/mindspore/lite/src/ops/tile.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Tile::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + auto tile_prim = this->primitive->value_as_Tile(); + MS_ASSERT(tile_prim != nullptr); + + std::vector out_shape; + std::vector multiples; + std::copy(tile_prim->multiples()->begin(), tile_prim->multiples()->end(), std::back_inserter(multiples)); + for (size_t i = 0; i < input->shape().size(); ++i) { + int tmp = input->shape()[i] * multiples[i]; + out_shape.push_back(tmp); + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc new file mode 100644 index 00000000000..a915be9d6cc --- /dev/null +++ b/mindspore/lite/src/ops/topk.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int TopK::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output0 = outputs_.front(); + MS_ASSERT(output0 != nullptr); + auto output1 = outputs_.at(1); + MS_ASSERT(output1 != nullptr); + auto topk_prim = this->primitive->value_as_TopK(); + MS_ASSERT(topk_prim != nullptr); + + output0->set_shape(input->shape()); + output0->set_data_type(input->data_type()); +// output0->shape().back() = topk_prim->k(); + + output1->set_shape(input->shape()); + output1->set_data_type(input->data_type()); +// output1->shape().back() = topk_prim->k(); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc new file mode 100644 index 00000000000..7165175f772 --- /dev/null +++ b/mindspore/lite/src/ops/transpose.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Transpose::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + MS_ASSERT(inputs_.size() == kSingleNum); + MS_ASSERT(outputs_.size() == kSingleNum); + auto transpore_prim = this->primitive->value_as_Transpose(); + int conjugate = transpore_prim->conjugate(); + if (conjugate) { + MS_LOG(ERROR) << "Transpose conjugate is not support currently"; + return RET_ERROR; + } + std::vector perm; + perm.insert(perm.begin(), transpore_prim->perm()->begin(), transpore_prim->perm()->end()); + + std::vector in_shape = input->shape(); + std::vector out_shape; + out_shape.resize(perm.size()); + for (int i = 0; i < perm.size(); ++i) { + out_shape[i] = in_shape[perm[i]]; + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc new file mode 100644 index 00000000000..c0405d3423f --- /dev/null +++ b/mindspore/lite/src/ops/unique.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unique::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { + MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + auto &input = inputs_.at(0); + MS_ASSERT(input != nullptr); + auto &output0 = outputs_.at(0); + MS_ASSERT(output0 != nullptr); + auto &output1 = outputs_.at(1); + MS_ASSERT(output1 != nullptr); + output0->set_shape(input->shape()); + output0->set_data_type(input->data_type()); + output1->set_shape(input->shape()); + output1->set_data_type(kNumberTypeInt32); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc new file mode 100644 index 00000000000..a93b5f871fc --- /dev/null +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unsqueeze::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "input size is invalid"; + } + if (outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "output size is invalid"; + } + auto unsqueeze_prim = this->primitive->value_as_Unsqueeze(); + auto dims = unsqueeze_prim->axis()->data(); + auto in_shape = input->shape(); + auto in_rank = in_shape.size(); + auto dim_rank = unsqueeze_prim->axis()->size(); + std::vector out_shape; + + if (dim_rank == 0) { + for (auto d : in_shape) { + if (d != 1) { + out_shape.push_back(d); + } + } + } else { + auto sz = in_rank + dim_rank; + int in_itr = 0; + int ax_itr = 0; + for (int i = 0; i < sz; i++) { + if (ax_itr < dim_rank && dims[ax_itr] == i) { + out_shape.emplace_back(1); + ax_itr++; + } else if (ax_itr < dim_rank && dims[ax_itr] + sz == i) { + out_shape.emplace_back(1); + ax_itr++; + } else { + if (in_shape[in_itr] > 1) { + out_shape.emplace_back(in_shape[in_itr]); + } + in_itr++; + } + } + } + + output->set_shape(out_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc new file mode 100644 index 00000000000..8df560a047b --- /dev/null +++ b/mindspore/lite/src/ops/unstack.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Unstack::InferShape(std::vector inputs, std::vector outputs) { + auto input = inputs.at(0); + MS_ASSERT(input != nullptr); + auto input_shape = input->shape(); + auto prim = this->primitive->value_as_Unstack(); + int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis(); + if (axis < 0 || axis >= input_shape.size()) { + MS_LOG(ERROR) << "Invalid axis " << prim->axis(); + return RET_PARAM_INVALID; + } + + std::vector output_shape; + for (size_t i = 0; i < input_shape.size(); ++i) { + if (i != axis) { + output_shape.push_back(input_shape.at(i)); + } + } + for (auto &out : outputs) { + MS_ASSERT(out != nullptr); + out->set_shape(output_shape); + out->set_data_type(input->data_type()); + } + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc new file mode 100644 index 00000000000..e6657c504e6 --- /dev/null +++ b/mindspore/lite/src/ops/where.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int Where::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size() + << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + if (inputs_.size() < 3) { + MS_LOG(ERROR) << "Input shape tensors should b"; + return RET_INPUT_TENSOR_ERROR; + } + auto input0 = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + int num = input0->ElementsNum(); + int num1 = input1->ElementsNum(); + int num2 = input2->ElementsNum(); + int nummax = num > num1 ? num : (num1 > num2 ? num1 : num2); + + auto shape_tmp = inputs_.at(0)->shape(); + auto shape_tmp1 = inputs_.at(1)->shape(); + auto shape_tmp2 = inputs_.at(2)->shape(); + int axisout = 0; + int temp = 0; + for (int j = 0; j < shape_tmp.size(); j++) { + if (shape_tmp[j] == shape_tmp1[j] && shape_tmp[j] != shape_tmp2[j]) { + axisout = j; + break; + } + if (shape_tmp[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) { + axisout = j; + break; + } + if (shape_tmp1[j] == shape_tmp2[j] && shape_tmp[j] != shape_tmp1[j]) { + axisout = j; + break; + } + temp += 1; + if (temp == shape_tmp.size()) { + outputs_[0]->set_shape(shape_tmp); + output->set_data_type(input->data_type()); + return RET_OK; + } + } + + auto output_shape = shape_tmp; + output_shape[axisout] = nummax; + outputs_[0]->set_shape(output_shape); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/ops/zeroslike.cc b/mindspore/lite/src/ops/zeroslike.cc new file mode 100644 index 00000000000..e5068bed526 --- /dev/null +++ b/mindspore/lite/src/ops/zeroslike.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +int ZerosLike::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { + MS_LOG(ERROR) << "zeroslike input or output number invalid, Input size:" << inputs_.size() + << ", output size: " << outputs_.size(); + return RET_INPUT_TENSOR_ERROR; + } + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + return RET_OK; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/param_value_lite.h b/mindspore/lite/src/param_value_lite.h new file mode 100644 index 00000000000..e387cb273ce --- /dev/null +++ b/mindspore/lite/src/param_value_lite.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ +#define MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ + +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +struct AnfQuantParam { + double scale; + int32_t zeroPoint; + double min; + double max; + bool narrowRange; + bool inited; + int32_t numBits; + AnfQuantParam() : scale(1.0), zeroPoint(0), min(0.0), max(0.0), narrowRange(false), numBits(8), inited(false) {} +}; +class ParamValueLite : public ParamValue { + public: + ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} + virtual ~ParamValueLite() = default; + + size_t tensor_size() const { return tensor_size_; } + void set_tensor_size(size_t size) { tensor_size_ = size; } + // todo + void *tensor_addr() const { return tensor_addr_; } + void set_tensor_addr(void *addr) { tensor_addr_ = addr; } + + std::vector tensor_shape() const { return tensor_shape_; } + void set_tensor_shape(std::vector tensor_shape) { tensor_shape_ = std::move(tensor_shape); } + + TypeId tensor_type() const { return type_id_; } + void set_tensor_type(TypeId type_id) { type_id_ = type_id; } + + int tensor_shape_size() const { + int size = 1; + for (auto val : tensor_shape_) { + size *= val; + } + return size; + } + std::vector> &quant_param() { return quant_params_; } + void set_quant_param(std::unique_ptr &quant_param) { + quant_params_.emplace_back(std::move(quant_param)); + } + + private: + void *tensor_addr_; + size_t tensor_size_; + std::vector tensor_shape_; + TypeId type_id_; + std::vector> quant_params_; +}; + +using ParamValueLitePtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ + diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc new file mode 100644 index 00000000000..e00ddf76743 --- /dev/null +++ b/mindspore/lite/src/populate_parameter.cc @@ -0,0 +1,1036 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/populate_parameter.h" +#include +#include "src/ops/ops.h" +#include "utils/log_adapter.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h" +#include "src/runtime/kernel/arm/opclib/fp32/cast.h" +#include "src/runtime/kernel/arm/opclib/concat_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/slice.h" +#include "src/runtime/kernel/arm/opclib/fp32/broadcast_to.h" +#include "src/runtime/kernel/arm/opclib/reshape_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/stack.h" +#include "src/runtime/kernel/arm/opclib/unstack.h" +#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" +#include "src/runtime/kernel/arm/opclib/matmul.h" +#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include "src/runtime/kernel/arm/opclib/tile.h" +#include "src/runtime/kernel/arm/opclib/topk.h" +#include "src/runtime/kernel/arm/opclib/fp32/reduce.h" +#include "src/runtime/kernel/arm/opclib/fp32/activation.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/opclib/fused_batchnorm.h" +#include "src/runtime/kernel/arm/opclib/power.h" +#include "src/runtime/kernel/arm/opclib/fp32/range.h" +#include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" +#include "src/runtime/kernel/arm/opclib/fp32/expandDims.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h" +#include "src/runtime/kernel/arm/opclib/pad.h" +#include "src/runtime/kernel/arm/opclib/fp32/fill.h" +#include "src/runtime/kernel/arm/opclib/transpose.h" +#include "src/runtime/kernel/arm/opclib/split.h" +#include "src/runtime/kernel/arm/opclib/squeeze.h" +#include "src/runtime/kernel/arm/opclib/fp32/gather.h" +#include "src/runtime/kernel/arm/opclib/fp32/reverse.h" +#include "src/runtime/kernel/arm/opclib/reverse_sequence.h" +#include "src/runtime/kernel/arm/opclib/unique.h" +#include "src/runtime/kernel/arm/opclib/scale.h" +#include "src/runtime/kernel/arm/opclib/fp32/gatherNd.h" +#include "src/runtime/kernel/arm/opclib/resize.h" +#include "src/runtime/kernel/arm/opclib/scatter_nd.h" +#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h" +#include "src/runtime/kernel/arm/opclib/fp32/crop.h" +#include "src/runtime/kernel/arm/fp32/flatten.h" +#include "src/runtime/kernel/arm/opclib/fp32/unsqueeze.h" +#include "src/runtime/kernel/arm/opclib/fp32/one_hot.h" +#include "src/runtime/kernel/arm/opclib/fp32/strided_slice.h" + +namespace mindspore::kernel { +FillParameter *PopulateFillParam(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_Fill(); + FillParameter *parameter = new (std::nothrow) FillParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new FillParameter failed."; + return nullptr; + } + auto flatDims = param->dims(); + parameter->num_dims_ = flatDims->size(); + int i = 0; + for (auto iter = flatDims->begin(); iter != flatDims->end(); iter++) { + parameter->dims_[i++] = *iter; + } + return parameter; +} + +ExpandDimsParameter *PopulateExpandDimsParam(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_ExpandDims(); + ExpandDimsParameter *parameter = new (std::nothrow) ExpandDimsParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ExpandDimsParameter failed."; + return nullptr; + } + parameter->dim_ = param->dim(); + return parameter; +} + +PoolingParameter *PopulatePoolingParam(const lite::Primitive *primitive) { + auto pooling_primitive = primitive->Value()->value_as_Pooling(); + // todo use malloc instead + PoolingParameter *parameter = new (std::nothrow) PoolingParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new PoolingParameter failed."; + return nullptr; + } + parameter->global_ = pooling_primitive->global(); + parameter->window_w_ = pooling_primitive->windowW(); + parameter->window_h_ = pooling_primitive->windowH(); + // todo format + auto pooling_lite_primitive = (lite::Pooling *)primitive; + MS_ASSERT(nullptr != pooling_lite_primitive); + parameter->pad_u_ = pooling_lite_primitive->PadUp(); + parameter->pad_d_ = pooling_lite_primitive->PadDown(); + parameter->pad_l_ = pooling_lite_primitive->PadLeft(); + parameter->pad_r_ = pooling_lite_primitive->PadRight(); + parameter->stride_w_ = pooling_primitive->strideW(); + parameter->stride_h_ = pooling_primitive->strideH(); + + auto pool_mode = pooling_primitive->poolingMode(); + switch (pool_mode) { + case schema::PoolMode_MAX_POOLING: + parameter->max_pooling_ = true; + parameter->avg_pooling_ = false; + break; + case schema::PoolMode_MEAN_POOLING: + parameter->max_pooling_ = false; + parameter->avg_pooling_ = true; + break; + default: + parameter->max_pooling_ = false; + parameter->avg_pooling_ = false; + break; + } + + auto round_mode = pooling_primitive->roundMode(); + switch (round_mode) { + case schema::RoundMode_FLOOR: + parameter->round_floor_ = true; + parameter->round_ceil_ = false; + break; + case schema::RoundMode_CEIL: + parameter->round_floor_ = false; + parameter->round_ceil_ = true; + break; + default: + parameter->round_floor_ = false; + parameter->round_ceil_ = false; + break; + } + return parameter; +} + +MatMulParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_FullConnection(); + MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new FullconnectionParameter failed."; + return nullptr; + } + parameter->b_transpose_ = true; + parameter->a_transpose_ = false; + parameter->has_bias_ = param->hasBias(); + parameter->minf_ = -FLT_MAX; + parameter->maxf_ = FLT_MAX; + return parameter; +} + +MatMulParameter *PopulateMatMulParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_MatMul(); + MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new FullconnectionParameter failed."; + return nullptr; + } + parameter->b_transpose_ = param->transposeB(); + parameter->a_transpose_ = param->transposeA(); + parameter->has_bias_ = false; + parameter->minf_ = -FLT_MAX; + parameter->maxf_ = FLT_MAX; + return parameter; +} + +ConvParameter *PopulateConvParameter(const lite::Primitive *primitive) { + ConvParameter *parameter = new (std::nothrow) ConvParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + auto conv_primitive = primitive->Value()->value_as_Conv2D(); + parameter->kernel_h_ = conv_primitive->kernelH(); + parameter->kernel_w_ = conv_primitive->kernelW(); + // todo format + parameter->group_ = conv_primitive->group(); + parameter->stride_h_ = conv_primitive->strideH(); + parameter->stride_w_ = conv_primitive->strideW(); + + auto conv2d_lite_primitive = (lite::Conv2D *)primitive; + MS_ASSERT(nullptr != conv2d_lite_primitive); + parameter->pad_u_ = conv2d_lite_primitive->PadUp(); + parameter->pad_d_ = conv2d_lite_primitive->PadDown(); + parameter->pad_l_ = conv2d_lite_primitive->PadLeft(); + parameter->pad_r_ = conv2d_lite_primitive->PadRight(); + parameter->pad_h_ = conv2d_lite_primitive->PadUp(); + parameter->pad_w_ = conv2d_lite_primitive->PadLeft(); + parameter->dilation_h_ = conv_primitive->dilateH(); + parameter->dilation_w_ = conv_primitive->dilateW(); + parameter->input_channel_ = conv_primitive->channelIn(); + parameter->output_channel_ = conv_primitive->channelOut(); + parameter->group_ = conv_primitive->group(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + parameter->is_relu_ = true; + parameter->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + parameter->is_relu_ = false; + parameter->is_relu6_ = true; + break; + default: + parameter->is_relu_ = false; + parameter->is_relu6_ = false; + break; + } + return parameter; +} + +ConvParameter *PopulateConvDwParameter(const lite::Primitive *primitive) { + ConvParameter *parameter = new (std::nothrow) ConvParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + auto conv_primitive = primitive->Value()->value_as_DepthwiseConv2D(); + parameter->kernel_h_ = conv_primitive->kernelH(); + parameter->kernel_w_ = conv_primitive->kernelW(); + // todo format, group + parameter->stride_h_ = conv_primitive->strideH(); + parameter->stride_w_ = conv_primitive->strideW(); + + auto pad_mode = conv_primitive->padMode(); + auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; + MS_ASSERT(nullptr != convdw_lite_primitive); + parameter->pad_u_ = convdw_lite_primitive->PadUp(); + parameter->pad_d_ = convdw_lite_primitive->PadDown(); + parameter->pad_l_ = convdw_lite_primitive->PadLeft(); + parameter->pad_r_ = convdw_lite_primitive->PadRight(); + parameter->pad_h_ = convdw_lite_primitive->PadUp(); + parameter->pad_w_ = convdw_lite_primitive->PadLeft(); + parameter->dilation_h_ = conv_primitive->dilateH(); + parameter->dilation_w_ = conv_primitive->dilateW(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + parameter->is_relu_ = true; + parameter->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + parameter->is_relu_ = false; + parameter->is_relu6_ = true; + break; + default: + parameter->is_relu_ = false; + parameter->is_relu6_ = false; + break; + } + return parameter; +} + +ConvParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { + ConvParameter *parameter = new ConvParameter(); + auto conv_primitive = primitive->Value()->value_as_DeDepthwiseConv2D(); + parameter->kernel_h_ = conv_primitive->kernelH(); + parameter->kernel_w_ = conv_primitive->kernelW(); + // todo format, group + parameter->stride_h_ = conv_primitive->strideH(); + parameter->stride_w_ = conv_primitive->strideW(); + + auto deconvdw_lite_primitive = (lite::DeconvDepthwiseConv2D *)primitive; + MS_ASSERT(nullptr != deconvdw_lite_primitive); + parameter->pad_u_ = deconvdw_lite_primitive->PadUp(); + parameter->pad_d_ = deconvdw_lite_primitive->PadDown(); + parameter->pad_l_ = deconvdw_lite_primitive->PadLeft(); + parameter->pad_r_ = deconvdw_lite_primitive->PadRight(); + parameter->pad_h_ = deconvdw_lite_primitive->PadUp(); + parameter->pad_w_ = deconvdw_lite_primitive->PadLeft(); + parameter->dilation_h_ = conv_primitive->dilateH(); + parameter->dilation_w_ = conv_primitive->dilateW(); + auto act_type = conv_primitive->activationType(); + switch (act_type) { + case schema::ActivationType_RELU: + parameter->is_relu_ = true; + parameter->is_relu6_ = false; + break; + case schema::ActivationType_RELU6: + parameter->is_relu_ = false; + parameter->is_relu6_ = true; + break; + default: + parameter->is_relu_ = false; + parameter->is_relu6_ = false; + break; + } + return parameter; +} + +SoftmaxParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { + auto softmax_primitive = primitive->Value()->value_as_SoftMax(); + SoftmaxParameter *parameter = new (std::nothrow) SoftmaxParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new SoftmaxParameter failed."; + return nullptr; + } + parameter->axis_ = softmax_primitive->axis(); + return parameter; +} + +ReduceParameter *PopulateReduceParameter(const lite::Primitive *primitive) { + ReduceParameter *parameter = new (std::nothrow) ReduceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ReduceParameter failed."; + return nullptr; + } + auto reduce = primitive->Value()->value_as_Reduce(); + parameter->keep_dims_ = reduce->keepDims(); + auto axisVector = reduce->axes(); + if (axisVector->size() > REDUCE_MAX_AXES_NUM) { + MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM; + delete (parameter); + return nullptr; + } + parameter->num_axes_ = static_cast(axisVector->size()); + int i = 0; + for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) { + parameter->axes_[i++] = *iter; + } + parameter->mode_ = static_cast(reduce->mode()); + return parameter; +} + +PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { + PadParameter *parameter = new (std::nothrow) PadParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new PadParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Pad(); + auto size = param->paddings()->size(); + parameter->ori_size_ = size; + auto valid_size = size <= 8 ? size : 8; + for (size_t i = 0; i < valid_size; i++) { + parameter->paddings[i] = (*(param->paddings()))[i]; + } + return parameter; +} + +ActivationParameter *PopulateActivationParameter(const lite::Primitive *primitive) { + ActivationParameter *parameter = new (std::nothrow) ActivationParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ActivationParameter failed."; + return nullptr; + } + auto activation = primitive->Value()->value_as_Activation(); + parameter->type_ = static_cast(activation->type()); + return parameter; +} + +FusedBatchNormParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) { + FusedBatchNormParameter *parameter = new (std::nothrow) FusedBatchNormParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new FusedBatchNormParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_FusedBatchNorm(); + parameter->epsilon_ = param->epsilon(); + return parameter; +} + +ArithmeticParameter *PopulateArithmetic(const lite::Primitive *primitive) { + ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + parameter->op_parameter.type_ = primitive->Type(); + parameter->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + parameter->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + (void)memcpy(parameter->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + (void)memcpy(parameter->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + (void)memcpy(parameter->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return parameter; +} + +ArithmeticParameter *PopulateEltwiseParam(const lite::Primitive *primitive) { + ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + auto eltwise = primitive->Value()->value_as_Eltwise(); + switch (eltwise->mode()) { + case schema::EltwiseMode_PROD: + parameter->op_parameter.type_ = schema::PrimitiveType_Mul; + break; + case schema::EltwiseMode_SUM: + parameter->op_parameter.type_ = schema::PrimitiveType_Add; + break; + case schema::EltwiseMode_MAXIMUM: + parameter->op_parameter.type_ = schema::PrimitiveType_Maximum; + break; + default: + delete parameter; + return nullptr; + } + return parameter; +} + +ArithmeticSelfParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) { + ArithmeticSelfParameter *parameter = new (std::nothrow) ArithmeticSelfParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ArithmeticParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + return parameter; +} + +PowerParameter *PopulatePowerParameter(const lite::Primitive *primitive) { + PowerParameter *parameter = new (std::nothrow) PowerParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new PowerParameter failed."; + return nullptr; + } + auto power = primitive->Value()->value_as_Power(); + parameter->power_ = power->power(); + parameter->scale_ = power->scale(); + parameter->shift_ = power->shift(); + return parameter; +} + +ArgMinMaxParameter *PopulateArgMinMaxParam(const lite::Primitive *primitive) { + ArgMinMaxParameter *parameter = new (std::nothrow) ArgMinMaxParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_ArgMax(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->axis_ = param->axis(); + parameter->topk_ = param->topK(); + parameter->axis_type_ = param->axisType(); + parameter->out_value_ = param->outMaxValue(); + parameter->keep_dims_ = param->keepDims(); + return parameter; +} + +CastParameter *PopulateCastParam(const lite::Primitive *primitive) { + CastParameter *parameter = new (std::nothrow) CastParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new CastParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Cast(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->src_type_ = param->srcT(); + parameter->dst_type_ = param->dstT(); + return parameter; +} + +LocalResponseNormParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) { + auto local_response_norm_attr = primitive->Value()->value_as_LocalResponseNormalization(); + LocalResponseNormParameter *parameter = new (std::nothrow) LocalResponseNormParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new LocalResponseNormParameter failed."; + return nullptr; + } + parameter->depth_radius_ = local_response_norm_attr->depth_radius(); + parameter->bias_ = local_response_norm_attr->bias(); + parameter->alpha_ = local_response_norm_attr->alpha(); + parameter->beta_ = local_response_norm_attr->beta(); + return parameter; +} + +RangeParameter *PopulateRangeParameter(const lite::Primitive *primitive) { + auto range_attr = primitive->Value()->value_as_Range(); + RangeParameter *parameter = new (std::nothrow) RangeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new RangeParameter failed."; + return nullptr; + } + parameter->start_ = range_attr->start(); + parameter->limit_ = range_attr->limit(); + parameter->delta_ = range_attr->delta(); + parameter->dType_ = range_attr->dType(); + return parameter; +} + +OpParameter *PopulateCeilParameter(const lite::Primitive *primitive) { + OpParameter *parameter = new (std::nothrow) OpParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new OpParameter failed."; + return nullptr; + } + parameter->type_ = primitive->Type(); + return parameter; +} + +ConcatParameter *PopulateConcatParameter(const lite::Primitive *primitive) { + ConcatParameter *parameter = new (std::nothrow) ConcatParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ConcatParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Concat(); + parameter->axis_ = param->axis(); + return parameter; +} + +TileParameter *PopulateTileParameter(const lite::Primitive *primitive) { + TileParameter *parameter = new (std::nothrow) TileParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new TileParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Tile(); + auto multiples = param->multiples(); + parameter->in_dim_ = multiples->size(); + for (size_t i = 0; i < parameter->in_dim_; ++i) { + parameter->multiples_[i] = multiples->Get(i); + } + return parameter; +} + +TopkParameter *PopulateTopKParameter(const lite::Primitive *primitive) { + TopkParameter *parameter = new (std::nothrow) TopkParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new TopkParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_TopK(); + parameter->k_ = param->k(); + parameter->sorted_ = param->sorted(); + return parameter; +} + +OpParameter *PopulateNhwc2NchwParameter(const lite::Primitive *primitive) { + OpParameter *parameter = new (std::nothrow) OpParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new Nhwc2NchwParameter failed."; + return nullptr; + } + parameter->type_ = primitive->Type(); + return parameter; +} + +OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) { + OpParameter *parameter = new (std::nothrow) OpParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new Nchw2NhwcParameter failed."; + return nullptr; + } + parameter->type_ = primitive->Type(); + return parameter; +} + +TransposeParameter *PopulateTransposeParameter(const lite::Primitive *primitive) { + TransposeParameter *parameter = new (std::nothrow) TransposeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new TransposeParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Transpose(); + parameter->op_parameter_.type_ = primitive->Type(); + auto perm_vector_ = param->perm(); + int i = 0; + for (auto iter = perm_vector_->begin(); iter != perm_vector_->end(); iter++) { + parameter->perm_[i++] = *iter; + } + parameter->num_axes_ = i; + parameter->conjugate_ = param->conjugate(); + return parameter; +} + +SplitParameter *PopulateSplitParameter(const lite::Primitive *primitive) { + SplitParameter *parameter = new (std::nothrow) SplitParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new SplitParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Split(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_split_ = param->numberSplit(); + auto split_sizes_vector_ = param->sizeSplits(); + int i = 0; + for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); iter++) { + parameter->split_sizes_[i++] = *iter; + } + parameter->split_dim_ = param->splitDim(); + parameter->num_split_ = param->numberSplit(); + return parameter; +} + +SqueezeParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) { + SqueezeParameter *parameter = new (std::nothrow) SqueezeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new SqueezeParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + return parameter; +} + +ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "input primitive is nullptr"; + return nullptr; + } + ScaleParameter *parameter = new (std::nothrow) ScaleParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ScaleParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Scale(); + if (param == nullptr) { + MS_LOG(ERROR) << "value_as_Scale return nullptr"; + return nullptr; + } + // NCHW todo use enum + if (param->format() == schema::Format_NCHW) { + parameter->axis_ = 1; + parameter->num_axis_ = 1; + } else if (param->format() == schema::Format_NHWC) { + parameter->axis_ = 3; + parameter->num_axis_ = 1; + } + + return parameter; +} + +GatherParameter *PopulateGatherParameter(const lite::Primitive *primitive) { + auto gather_attr = primitive->Value()->value_as_Gather(); + GatherParameter *parameter = new (std::nothrow) GatherParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new GatherParameter failed."; + return nullptr; + } + parameter->axis_ = gather_attr->axis(); + parameter->batchDims_ = gather_attr->batchDims(); + return parameter; +} + +GatherNdParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) { + GatherNdParameter *parameter = new (std::nothrow) GatherNdParameter(); + MS_ASSERT(paramter != nullptr); + auto gatherNd_attr = primitive->Value()->value_as_GatherNd(); + parameter->batchDims_ = gatherNd_attr->batchDims(); + return parameter; +} + +ScatterNDParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) { + ScatterNDParameter *parameter = new (std::nothrow) ScatterNDParameter(); + MS_ASSERT(paramter != nullptr); + return parameter; +} + +SliceParameter *PopulateSliceParam(const lite::Primitive *primitive) { + SliceParameter *parameter = new (std::nothrow) SliceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new SliceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Slice(); + parameter->op_parameter_.type_ = primitive->Type(); + auto param_begin = param->begin(); + auto param_size = param->size(); + if (param_begin->size() != param_size->size()) { + delete parameter; + return nullptr; + } + parameter->param_length_ = static_cast(param_begin->size()); + for (int32_t i = 0; i < parameter->param_length_; ++i) { + parameter->begin_[i] = param_begin->Get(i); + parameter->size_[i] = param_size->Get(i); + } + return parameter; +} + +BroadcastToParameter *PopulateBroadcastToParam(const lite::Primitive *primitive) { + BroadcastToParameter *parameter = new (std::nothrow) BroadcastToParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new BroadcastToParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_BroadcastTo(); + parameter->op_parameter_.type_ = primitive->Type(); + auto dst_shape = param->dst_shape(); + parameter->shape_size_ = dst_shape->size(); + for (size_t i = 0; i < parameter->shape_size_; ++i) { + parameter->shape_[i] = dst_shape->Get(i); + } + return parameter; +} + +ReshapeParameter *PopulateReshapeParam(const lite::Primitive *primitive) { + ReshapeParameter *parameter = new (std::nothrow) ReshapeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ReshapeParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + return parameter; +} + +ReverseParameter *PopulateReverseParameter(const lite::Primitive *primitive) { + auto reverse_attr = primitive->Value()->value_as_Reverse(); + ReverseParameter *parameter = new (std::nothrow) ReverseParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ReverseParameter failed."; + return nullptr; + } + auto flatAxis = reverse_attr->axis(); + parameter->num_axis_ = flatAxis->size(); + int i = 0; + for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { + parameter->axis_[i++] = *iter; + } + return parameter; +} + +UnsqueezeParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) { + auto unsqueeze_attr = primitive->Value()->value_as_Unsqueeze(); + UnsqueezeParameter *parameter = new (std::nothrow) UnsqueezeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ReverseParameter failed."; + return nullptr; + } + auto flatAxis = unsqueeze_attr->axis(); + parameter->num_dim_ = flatAxis->size(); + int i = 0; + for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { + parameter->dims_[i++] = *iter; + } + return parameter; +} + +StackParameter *PopulateStackParam(const lite::Primitive *primitive) { + StackParameter *parameter = new (std::nothrow) StackParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new StackParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Stack(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->axis_ = param->axis(); + return parameter; +} + +UnstackParameter *PopulateUnstackParam(const lite::Primitive *primitive) { + UnstackParameter *parameter = new (std::nothrow) UnstackParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new UnstackParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Unstack(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_ = param->num(); + parameter->axis_ = param->axis(); + return parameter; +} + +ReverseSequenceParameter *PopulateReverseSequenceParam(const lite::Primitive *primitive) { + ReverseSequenceParameter *parameter = new (std::nothrow) ReverseSequenceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ReverseSequenceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_ReverseSequence(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->seq_axis_ = param->seqAxis(); + parameter->batch_axis_ = param->batchAxis(); + return parameter; +} + +UniqueParameter *PopulateUniqueParam(const lite::Primitive *primitive) { + UniqueParameter *parameter = new (std::nothrow) UniqueParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new PopulateUniqueParam failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + return parameter; +} + +DepthToSpaceParameter *PopulateDepthToSpaceParam(const lite::Primitive *primitive) { + DepthToSpaceParameter *parameter = new (std::nothrow) DepthToSpaceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new DepthToSpaceParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_DepthToSpace(); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->block_size_ = param->blockSize(); + return parameter; +} + +ResizeParameter *PopulateResizeParameter(const lite::Primitive *primitive) { + ResizeParameter *parameter = new (std::nothrow) ResizeParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new ResizeParameter failed."; + return nullptr; + } + auto param = primitive->Value()->value_as_Resize(); + parameter->method_ = param->method(); + parameter->new_height_ = param->newHeight(); + parameter->new_width_ = param->newWidth(); + parameter->align_corners_ = param->alignCorners(); + parameter->preserve_aspect_ratio_ = param->preserveAspectRatio(); + return parameter; +} + +BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) { + BatchToSpaceParameter *parameter = new (std::nothrow) BatchToSpaceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "New BatchToSpaceParameter fail!"; + return nullptr; + } + auto param = primitive->Value()->value_as_BatchToSpace(); + auto block_shape = param->blockShape(); + if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; + return nullptr; + } + + auto crops = param->crops(); + if (crops->size() != BATCH_TO_SPACE_CROPS_SIZE) { + MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE; + return nullptr; + } + + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + parameter->block_shape_[i] = block_shape->Get(i); + } + + for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { + parameter->crops_[i] = crops->Get(i); + } + return parameter; +} + +CropParameter *PopulateCropParameter(const lite::Primitive *primitive) { + auto param = primitive->Value()->value_as_Crop(); + auto param_offset = param->offsets(); + if (param_offset->size() > CROP_OFFSET_MAX_SIZE) { + MS_LOG(ERROR) << "parameter offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE; + return nullptr; + } + CropParameter *parameter = new (std::nothrow) CropParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new CropParameter fail!"; + return nullptr; + } + parameter->axis_ = param->axis(); + for (int i = 0; i < param_offset->size(); ++i) { + parameter->offset_[i] = param_offset->Get(i); + } + return parameter; +} + +OneHotParameter *PopulateOneHotParameter(const lite::Primitive *primitive) { + OneHotParameter *parameter = new (std::nothrow) OneHotParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new OneHotParameter fail!"; + return nullptr; + } + auto param = primitive->Value()->value_as_OneHot(); + if (param == nullptr) { + delete (parameter); + MS_LOG(ERROR) << "get OneHot param nullptr."; + return nullptr; + } + parameter->axis_ = param->axis(); + return parameter; +} + +FlattenParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { + FlattenParameter *parameter = new (std::nothrow) FlattenParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new FlattenParameter fail!"; + return nullptr; + } + return parameter; +} + +StridedSliceParameter *PopulateStridedSliceParam(const lite::Primitive *primitive) { + StridedSliceParameter *parameter = new (std::nothrow) StridedSliceParameter(); + if (parameter == nullptr) { + MS_LOG(ERROR) << "new StridedSliceParameter failed."; + return nullptr; + } + parameter->op_parameter_.type_ = primitive->Type(); + auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); + parameter->num_axes_ = n_dims; + auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins(); + (void)memcpy(parameter->begins_, (begin.data()), begin.size() * sizeof(int)); + auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds(); + (void)memcpy(parameter->ends_, (end.data()), end.size() * sizeof(int)); + auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides(); + (void)memcpy(parameter->strides_, (stride.data()), stride.size() * sizeof(int)); + auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape(); + (void)memcpy(parameter->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + return parameter; +} + +OpParameter *PopulateParameter(const lite::Primitive *primitive) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_type = primitive->Type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return reinterpret_cast(PopulateSoftmaxParameter(primitive)); + case schema::PrimitiveType_Activation: + return reinterpret_cast(PopulateActivationParameter(primitive)); + case schema::PrimitiveType_Conv2D: + return reinterpret_cast(PopulateConvParameter(primitive)); + case schema::PrimitiveType_Reduce: + return reinterpret_cast(PopulateReduceParameter(primitive)); + case schema::PrimitiveType_Pooling: + return reinterpret_cast(PopulatePoolingParam(primitive)); + case schema::PrimitiveType_DepthwiseConv2D: + return reinterpret_cast(PopulateConvDwParameter(primitive)); + case schema::PrimitiveType_DeDepthwiseConv2D: + return reinterpret_cast(PopulateDeconvDwParameter(primitive)); + case schema::PrimitiveType_FusedBatchNorm: + return reinterpret_cast(PopulateFusedBatchNorm(primitive)); + case schema::PrimitiveType_FullConnection: + return reinterpret_cast(PopulateFullconnectionParameter(primitive)); + case schema::PrimitiveType_Power: + return reinterpret_cast(PopulatePowerParameter(primitive)); + case schema::PrimitiveType_LocalResponseNormalization: + return reinterpret_cast(PopulateLocalResponseNormParameter(primitive)); + case schema::PrimitiveType_Range: + return reinterpret_cast(PopulateRangeParameter(primitive)); + case schema::PrimitiveType_Transpose: + return reinterpret_cast(PopulateTransposeParameter(primitive)); + case schema::PrimitiveType_Mul: + case schema::PrimitiveType_Add: + case schema::PrimitiveType_Sub: + case schema::PrimitiveType_Div: + case schema::PrimitiveType_FloorDiv: + case schema::PrimitiveType_FloorMod: + case schema::PrimitiveType_SquaredDifference: + return reinterpret_cast(PopulateArithmetic(primitive)); + case schema::PrimitiveType_BiasAdd: + return reinterpret_cast(new ArithmeticParameter()); + case schema::PrimitiveType_Eltwise: + return reinterpret_cast(PopulateEltwiseParam(primitive)); + case schema::PrimitiveType_ExpandDims: + return reinterpret_cast(PopulateExpandDimsParam(primitive)); + case schema::PrimitiveType_Abs: + case schema::PrimitiveType_Cos: + case schema::PrimitiveType_Sin: + case schema::PrimitiveType_Exp: + case schema::PrimitiveType_Log: + case schema::PrimitiveType_Square: + case schema::PrimitiveType_Sqrt: + case schema::PrimitiveType_Rsqrt: + case schema::PrimitiveType_LogicalNot: + case schema::PrimitiveType_Floor: + return reinterpret_cast(PopulateArithmeticSelf(primitive)); + case schema::PrimitiveType_ArgMax: + case schema::PrimitiveType_ArgMin: + return reinterpret_cast(PopulateArgMinMaxParam(primitive)); + case schema::PrimitiveType_Cast: + return reinterpret_cast(PopulateCastParam(primitive)); + case schema::PrimitiveType_Ceil: + return reinterpret_cast(PopulateCeilParameter(primitive)); + case schema::PrimitiveType_Scale: + return reinterpret_cast(PopulateScaleParameter(primitive)); + case schema::PrimitiveType_Reshape: + return reinterpret_cast(PopulateReshapeParam(primitive)); + case schema::PrimitiveType_Concat: + return reinterpret_cast(PopulateConcatParameter(primitive)); + case schema::PrimitiveType_Tile: + return reinterpret_cast(PopulateTileParameter(primitive)); + case schema::PrimitiveType_TopK: + return reinterpret_cast(PopulateTopKParameter(primitive)); + case schema::PrimitiveType_Fill: + return reinterpret_cast(PopulateFillParam(primitive)); + case schema::PrimitiveType_Gather: + return reinterpret_cast(PopulateGatherParameter(primitive)); + case schema::PrimitiveType_GatherNd: + return reinterpret_cast(PopulateGatherNdParameter(primitive)); + case schema::PrimitiveType_Slice: + return reinterpret_cast(PopulateSliceParam(primitive)); + case schema::PrimitiveType_BroadcastTo: + return reinterpret_cast(PopulateBroadcastToParam(primitive)); + case schema::PrimitiveType_Reverse: + return reinterpret_cast(PopulateReverseParameter(primitive)); + case schema::PrimitiveType_Stack: + return reinterpret_cast(PopulateStackParam(primitive)); + case schema::PrimitiveType_Unstack: + return reinterpret_cast(PopulateUnstackParam(primitive)); + case schema::PrimitiveType_ReverseSequence: + return reinterpret_cast(PopulateReverseSequenceParam(primitive)); + case schema::PrimitiveType_Unique: + return reinterpret_cast(PopulateUniqueParam(primitive)); + case schema::PrimitiveType_DepthToSpace: + return reinterpret_cast(PopulateDepthToSpaceParam(primitive)); + case schema::PrimitiveType_Nchw2Nhwc: + return reinterpret_cast(PopulateNchw2NhwcParameter(primitive)); + case schema::PrimitiveType_Nhwc2Nchw: + return reinterpret_cast(PopulateNhwc2NchwParameter(primitive)); + case schema::PrimitiveType_Pad: + return reinterpret_cast(PopulatePadParameter(primitive)); + case schema::PrimitiveType_Resize: + return reinterpret_cast(PopulateResizeParameter(primitive)); + case schema::PrimitiveType_BatchToSpace: + return reinterpret_cast(PopulateBatchToSpaceParameter(primitive)); + case schema::PrimitiveType_Crop: + return reinterpret_cast(PopulateCropParameter(primitive)); + case schema::PrimitiveType_Unsqueeze: + return reinterpret_cast(PopulateUnsqueezeParameter(primitive)); + case schema::PrimitiveType_Flatten: + return reinterpret_cast(PopulateFlattenParameter(primitive)); + case schema::PrimitiveType_MatMul: + return reinterpret_cast(PopulateMatMulParameter(primitive)); + case schema::PrimitiveType_OneHot: + return reinterpret_cast(PopulateOneHotParameter(primitive)); + default: + break; + } + return nullptr; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/populate_parameter.h b/mindspore/lite/src/populate_parameter.h new file mode 100644 index 00000000000..f2e9ab6afeb --- /dev/null +++ b/mindspore/lite/src/populate_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ + +#include "schema/model_generated.h" +#include "src/ops/ops.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +namespace mindspore::kernel { +OpParameter *PopulateParameter(const lite::Primitive *primitive); +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/allocator.cc b/mindspore/lite/src/runtime/allocator.cc new file mode 100644 index 00000000000..3b047e66904 --- /dev/null +++ b/mindspore/lite/src/runtime/allocator.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/allocator.h" +#include +#include "utils/log_adapter.h" + +namespace mindspore::lite { +std::shared_ptr Allocator::Create() { return std::shared_ptr(new DefaultAllocator()); } + +DefaultAllocator::DefaultAllocator() {} + +DefaultAllocator::~DefaultAllocator() { Clear(); } + +void DefaultAllocator::SetContext(const AllocatorContext &ctx) { + lockFlag = ctx.lockFlag; + shiftFactor = ctx.shiftFactor; +} + +void DefaultAllocator::Lock() { + if (lockFlag) { + lock.lock(); + } +} + +void DefaultAllocator::UnLock() { + if (lockFlag) { + lock.unlock(); + } +} + +void *DefaultAllocator::Malloc(size_t size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + Lock(); + auto iter = freeList.lower_bound(size); + if (iter != freeList.end() && (iter->second->size >= size) && (iter->second->size < (size << shiftFactor))) { + auto membuf = iter->second; + freeList.erase(iter); + allocatedList[membuf->buf] = membuf; + UnLock(); + return membuf->buf; + } + + std::unique_ptr membuf(reinterpret_cast(malloc(sizeof(MemBuf) + size))); + if (membuf == nullptr) { + MS_LOG(ERROR) << "malloc membuf return nullptr"; + UnLock(); + return nullptr; + } + membuf->size = size; + membuf->buf = reinterpret_cast(membuf.get()) + sizeof(MemBuf); + auto bufPtr = membuf->buf; + allocatedList[bufPtr] = membuf.release(); + UnLock(); + return bufPtr; +} + +void DefaultAllocator::Free(void *buf) { + if (buf == nullptr) { + return; + } + Lock(); + auto iter = allocatedList.find(buf); + if (iter != allocatedList.end()) { + auto membuf = iter->second; + allocatedList.erase(iter); + freeList.insert(std::make_pair(membuf->size, membuf)); + UnLock(); + return; + } + UnLock(); + free(buf); +} + +size_t DefaultAllocator::GetTotalSize() { + Lock(); + size_t totalSize = 0; + + for (auto it = allocatedList.begin(); it != allocatedList.end(); it++) { + auto membuf = it->second; + totalSize += membuf->size; + } + + for (auto it = freeList.begin(); it != freeList.end(); it++) { + auto membuf = it->second; + totalSize += membuf->size; + } + UnLock(); + return totalSize; +} + +void DefaultAllocator::Clear() { + Lock(); + + for (auto it = allocatedList.begin(); it != allocatedList.end(); it++) { + free(it->second); + } + allocatedList.clear(); + + for (auto it = freeList.begin(); it != freeList.end(); it++) { + free(it->second); + } + freeList.clear(); + UnLock(); +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/runtime/allocator.h b/mindspore/lite/src/runtime/allocator.h new file mode 100644 index 00000000000..9d44bd6f891 --- /dev/null +++ b/mindspore/lite/src/runtime/allocator.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore::lite { +struct AllocatorContext { + int shiftFactor; + bool lockFlag; +}; + +class Allocator { + public: + Allocator() : name("default") {} + virtual ~Allocator() {} + virtual void *Malloc(size_t size) = 0; + virtual void Free(void *ptr) = 0; + virtual void SetContext(const AllocatorContext &ctx) {} + virtual size_t GetTotalSize() { return 0; } + virtual void Clear() {} + static std::shared_ptr Create(); + std::string name; +}; + +class DefaultAllocator : public Allocator { + public: + DefaultAllocator(); + ~DefaultAllocator() override; + void SetContext(const AllocatorContext &ctx) override; + void *Malloc(size_t size) override; + void Free(void *ptr) override; + size_t GetTotalSize() override; + void Clear() override; + + private: + void Lock(); + void UnLock(); + struct MemBuf { + size_t size; + void *buf; + }; + + std::mutex lock; + // buf, membuf> + std::unordered_map allocatedList; + std::multimap freeList; + // 6 is empirical value + int shiftFactor = 6; + bool lockFlag = false; +}; + +#define MAX_MALLOC_SIZE 500 * 1024 * 1024 + +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_RUNTIME_ALLOCATOR_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt new file mode 100644 index 00000000000..2f88b378a34 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -0,0 +1,34 @@ +file(GLOB_RECURSE KERNEL_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp32/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/int8/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc + ) + +if (PLATFORM_ARM64) + # assembly + file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.s + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm64/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) +endif() + +if (PLATFORM_ARM32) + # assembly + file(GLOB_RECURSE ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/opclib/assembly/arm32/*.s) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) +endif() + +if (ENABLE_FP16) + file(GLOB_RECURSE FP6_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opclib/fp16/*.cc + ) + set(KERNEL_SRC ${KERNEL_SRC} ${FP6_SRC}) +endif () + +add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC}) +add_subdirectory(opclib) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc new file mode 100644 index 00000000000..bfb0689ffc4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/concat_base.h" +#include +#include "src/runtime/kernel/arm/int8/concat_int8.h" +#include "src/runtime/kernel/arm/fp32/concat.h" +#include "src/runtime/kernel/arm/opclib/fp32/concat.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { +int ConcatBaseCPUKernel::Init() { + axis_ = concat_param_->axis_ >= 0 ? concat_param_->axis_ : inputs_.front()->shape().size() + concat_param_->axis_; + return RET_OK; +} + +kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConcatFp32OrInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConcatKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + case kNumberTypeUInt8: + kernel = CpuConcatInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; + case kNumberTypeInt32: + case kNumberTypeFloat32: + kernel = CpuConcatFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Concat, CpuConcatKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h new file mode 100644 index 00000000000..af0f572ff3f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/concat_parameter.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatBaseCPUKernel : public LiteKernel { + public: + ConcatBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + opParameter->thread_num_ = ctx->threadNum; + concat_param_ = reinterpret_cast(opParameter); + } + + ~ConcatBaseCPUKernel() = default; + + int Init() override; + + int ReSize() override { return 0; } + + int Run() override { return 0; } + protected: + int thread_count_; + int axis_; + const Context *ctx_; + ConcatParameter *concat_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc new file mode 100644 index 00000000000..2a59f544521 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -0,0 +1,502 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/fp32/convolution.h" +#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" +#include "src/runtime/kernel/arm/fp32/deconvolution.h" +#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" +#include "src/runtime/kernel/arm/fp32/convolution_3x3.h" +#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" +#include "src/runtime/kernel/arm/fp32/deconvolution_depthwise.h" +#ifdef ENABLE_FP16 +#include "src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" +#endif +#include "src/runtime/kernel/arm/int8/deconvolution_int8.h" +#include "src/runtime/kernel/arm/int8/convolution_int8.h" +#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" +#include "src/runtime/kernel/arm/int8/convolution_depthwise_int8.h" +#include "src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType; +using mindspore::schema::PadMode; +using mindspore::schema::PrimitiveType_Conv2D; +using mindspore::schema::PrimitiveType_DeConv2D; +using mindspore::schema::PrimitiveType_DeDepthwiseConv2D; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { + if (bias_data_ != nullptr) { + free(bias_data_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } +} + +void ConvolutionBaseCPUKernel::FreeQuantParam() { + if (quant_args_ != nullptr) { + for (int i = 0; i < 3; ++i) { + if (*(quant_args_ + i) != nullptr) { + free(*(quant_args_ + i)); + } + } + } + if (conv_quant_arg_ != nullptr) { + if (conv_quant_arg_->real_multiplier_ != nullptr) { + free(conv_quant_arg_->real_multiplier_); + } + if (conv_quant_arg_->left_shift_ != nullptr) { + free(conv_quant_arg_->left_shift_); + } + if (conv_quant_arg_->right_shift_ != nullptr) { + free(conv_quant_arg_->right_shift_); + } + if (conv_quant_arg_->quant_multiplier_ != nullptr) { + free(conv_quant_arg_->quant_multiplier_); + } + if (conv_quant_arg_->out_act_min_ != nullptr) { + free(conv_quant_arg_->out_act_min_); + } + if (conv_quant_arg_->out_act_max_ != nullptr) { + free(conv_quant_arg_->out_act_max_); + } + free(conv_quant_arg_); + } +} + +int ConvolutionBaseCPUKernel::Init() { + auto input = this->inputs_.front(); + auto output = this->outputs_.front(); + + conv_param_->input_batch_ = input->Batch(); + conv_param_->input_h_ = input->Height(); + conv_param_->input_w_ = input->Width(); + conv_param_->input_channel_ = input->Channel(); + conv_param_->output_batch_ = output->Batch(); + conv_param_->output_h_ = output->Height(); + conv_param_->output_w_ = output->Width(); + conv_param_->output_channel_ = output->Channel(); + conv_param_->thread_num_ = ctx_->threadNum; + + return RET_OK; +} + +int ConvolutionBaseCPUKernel::CheckLayout(lite::tensor::Tensor *input_tensor) { + auto data_type = input_tensor->data_type(); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + convert_func_ = LayoutTransform(data_type, input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionBaseCPUKernel::SetQuantParam() { + conv_quant_arg_ = new ConvQuantArg(); + quant_args_ = reinterpret_cast(malloc(3 * sizeof(QuantArg *))); + // per-tensor init + for (int j = 0; j < 3; ++j) { + quant_args_[j] = reinterpret_cast(malloc(sizeof(QuantArg))); + } + auto input_tensor = inputs_.at(kInputIndex); + auto weight_tensor = inputs_.at(kWeightIndex); + auto output_tensor = outputs_.at(kOutputIndex); + auto input_quant_arg = input_tensor->GetQuantParams().front(); + auto weight_quant_arg = weight_tensor->GetQuantParams().front(); + auto output_quant_arg = output_tensor->GetQuantParams().front(); + // input + quant_args_[0][0].zp_ = input_quant_arg.zeroPoint; + quant_args_[0][0].scale_ = input_quant_arg.scale; + // weight + quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint; + quant_args_[1][0].scale_ = weight_quant_arg.scale; + // output + quant_args_[2][0].zp_ = output_quant_arg.zeroPoint; + quant_args_[2][0].scale_ = output_quant_arg.scale; + + conv_quant_arg_->quant_args_ = quant_args_; + conv_quant_arg_->real_multiplier_ = reinterpret_cast(malloc(sizeof(double))); + conv_quant_arg_->left_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->right_shift_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->quant_multiplier_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->out_act_min_ = reinterpret_cast(malloc(sizeof(int32_t))); + conv_quant_arg_->out_act_max_ = reinterpret_cast(malloc(sizeof(int32_t))); + + double real_multiplier = weight_quant_arg.scale * input_quant_arg.scale / output_quant_arg.scale; + conv_quant_arg_->real_multiplier_[0] = real_multiplier; + QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0], + &conv_quant_arg_->right_shift_[0]); + + conv_param_->conv_quant_arg_ = *conv_quant_arg_; + ComputeQuantOutRange(conv_param_); + return RET_OK; +} + +void ComputeQuantOutRange(ConvParameter *conv_param) { + int32_t min = std::numeric_limits::min(); + int32_t max = std::numeric_limits::max(); + float scale = conv_param->conv_quant_arg_.quant_args_[2][0].scale_; + int32_t zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + bool is_relu = conv_param->is_relu_; + bool is_relu6 = conv_param->is_relu6_; + int32_t quantized_zero = QuantizeToInt8(0, scale, zp); + int32_t quantized_six = QuantizeToInt8(6, scale, zp); + if (is_relu) { + min = min > quantized_zero ? min : quantized_zero; + } else if (is_relu6) { + min = min > quantized_zero ? min : quantized_zero; + max = max < quantized_six ? max : quantized_six; + } else { + // do nothing + } + conv_param->conv_quant_arg_.out_act_min_[0] = min; + conv_param->conv_quant_arg_.out_act_max_[0] = max; +} + +void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) { + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + int input_unit = conv_param->kernel_h_ + *output_unit - 1; + input_trans_func = GetInputTransFunc(input_unit); + if (input_trans_func == nullptr) { + MS_LOG(INFO) << "No matching input trans func. Turn back to common conv."; + *use_winograd = false; + } + output_trans_func = GetOutputTransFunc(input_unit, *output_unit); + if (output_trans_func == nullptr) { + MS_LOG(INFO) << "No matching output trans func. Turn back to common conv."; + *use_winograd = false; + } + } else { + *use_winograd = false; + } + } else { + *use_winograd = false; + } +} + +kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + bool use_winograd; + int out_unit; + InputTransformUnitFunc input_trans_func = nullptr; + OutputTransformUnitFunc output_trans_func = nullptr; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); + + if (kernel_h == 1 && kernel_w == 1) { + auto kernel = new (std::nothrow) Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + auto kernel = new (std::nothrow) Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } else if (use_winograd) { + auto kernel = new (std::nothrow) ConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, out_unit); + return kernel; + } else { + auto kernel = new (std::nothrow) ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } +} + +#ifdef ENABLE_FP16 +kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + auto kernel = new (std::nothrow) Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } else { + auto kernel = new (std::nothrow) ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } +} +#endif + +kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + auto conv_param = reinterpret_cast(opParameter); + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + + if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { + auto kernel = new (std::nothrow) Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } else { + auto kernel = new (std::nothrow) ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx); + return kernel; + } +} + +kernel::LiteKernel *CpuConvKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + break; + case kNumberTypeUInt8: + kernel = CpuConvInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + kernel = CpuConvFp16KernelCreator(inputs, outputs, opParameter, ctx); + break; +#endif + case kNumberTypeFloat32: + kernel = CpuConvFp32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + auto kernel = new (std::nothrow) ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + auto kernel = new (std::nothrow) ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuConvDwKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + kernel = CpuConvDwInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; + case kNumberTypeUInt8: + break; +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + break; +#endif + case kNumberTypeFloat32: + kernel = CpuConvDwFp32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx) { + auto kernel = new (std::nothrow) DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx) { + auto kernel = new (std::nothrow) DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeconvDwKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + kernel = CpuDeconvDwInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; + case kNumberTypeFloat32: + kernel = CpuDeconvDwFp32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx) { + auto kernel = new (std::nothrow) DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx) { + auto kernel = new (std::nothrow) DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuDeConvKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + break; + case kNumberTypeUInt8: + kernel = CpuDeConvInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + break; +#endif + case kNumberTypeFloat32: + kernel = CpuDeConvFp32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Conv2D, CpuConvKernelCreator) +REG_KERNEL(kCPU, PrimitiveType_DeConv2D, CpuDeConvKernelCreator) +REG_KERNEL(kCPU, PrimitiveType_DepthwiseConv2D, CpuConvDwKernelCreator) +REG_KERNEL(kCPU, PrimitiveType_DeDepthwiseConv2D, CpuDeconvDwKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h new file mode 100644 index 00000000000..b5bb0de622d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ + +#include +#include +#include +#include +#ifdef ENABLE_ARM +#include +#include +#endif +#include "src/lite_kernel.h" + +#include "include/context.h" + +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; +using mindspore::schema::PadMode; +using mindspore::schema::QuantType; + +namespace mindspore::kernel { +class ConvolutionBaseCPUKernel : public LiteKernel { + public: + ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + opParameter->thread_num_ = ctx->threadNum; + conv_param_ = reinterpret_cast(opParameter); + } + ~ConvolutionBaseCPUKernel() override; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + virtual int CheckLayout(lite::tensor::Tensor *input_tensor); + int SetQuantParam(); + void FreeQuantParam(); + + protected: + int thread_count_; + int tile_num_; + void *bias_data_ = nullptr; + void *nhwc4_input_; + const Context *ctx_; + ConvParameter *conv_param_; + ConvQuantArg *conv_quant_arg_ = nullptr; + QuantArg **quant_args_ = nullptr; + LayoutConvertor convert_func_; +}; +void ComputeQuantOutRange(ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc new file mode 100644 index 00000000000..29fc6ec95b4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/fullconnection_base.h" +#include "src/runtime/kernel/arm/int8/fullconnection_int8.h" +#include "src/runtime/kernel/arm/fp32/fullconnection.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_FullConnection; + +namespace mindspore::kernel { +int FullconnectionBaseCPUKernel::Init() { + fc_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +kernel::LiteKernel *CpuFullConnectionKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + case kNumberTypeUInt8: { + kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + break; + } + + case kNumberTypeFloat32: { + kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx); + if (!kernel) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + break; + } + + default: + break; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_FullConnection, CpuFullConnectionKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h new file mode 100644 index 00000000000..a02368c34ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/matmul.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionBaseCPUKernel : public LiteKernel { + public: + FullconnectionBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + fc_param_ = reinterpret_cast(opParameter); + } + ~FullconnectionBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + MatMulParameter *fc_param_; + int thread_count_; + int thread_stride_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_FULLCONNECTION_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc new file mode 100644 index 00000000000..c495f19b0d8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::schema::Format; +namespace mindspore::kernel { +#ifdef ENABLE_FP16 +LayoutConvertor LayoutTransformFp16(schema::Format src_format, schema::Format dst_format) { + // todo + return nullptr; +} +#endif +LayoutConvertor LayoutTransformFp32(schema::Format src_format, schema::Format dst_format) { + // todo + if (src_format == schema::Format_NHWC && dst_format == schema::Format_NC4HW4) { + return PackNHWCToNC4HW4Fp32; + } else if (src_format == schema::Format_NHWC && dst_format == schema::Format_NHWC4) { + return PackNHWCToNHWC4Fp32; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC4) { + return PackNC4HW4ToNHWC4Fp32; + } else if (src_format == schema::Format_NCHW && dst_format == schema::Format_NC4HW4) { + return PackNCHWToNC4HW4Fp32; + } else if (src_format == schema::Format_NC4HW4 && dst_format == schema::Format_NHWC) { + return PackNC4HW4ToNHWCFp32; + } else { + MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(src_format) << " to " + << schema::EnumNameFormat(dst_format); + return nullptr; + } +} + +LayoutConvertor LayoutTransformInt8(schema::Format src_format, schema::Format dst_format) { + // todo + if (src_format == schema::Format_NHWC && dst_format == schema::Format_NHWC4) { + return PackNHWCToNHWC4Int8; + } else { + return nullptr; + } +} + +LayoutConvertor LayoutTransform(TypeId data_type, schema::Format src_format, schema::Format dst_format) { + // todo + switch (data_type) { + case kNumberTypeInt8: + return LayoutTransformInt8(src_format, dst_format); +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + return LayoutTransformFp16(src_format, dst_format); +#endif + case kNumberTypeFloat32: + return LayoutTransformFp32(src_format, dst_format); + default: + return nullptr; + } +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h new file mode 100644 index 00000000000..b09e533bbb6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/layout_transform.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ + +#ifdef ENABLE_FP16 +#include +#endif +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/ir/tensor.h" + +namespace mindspore::kernel { +typedef void (*LayoutConvertor)(const void *src, void *dst, int batch, int plane, int channel); +#ifdef ENABLE_FP16 +LayoutConvertor LayoutTransformFp16(schema::Format src_format, schema::Format dst_format); +#endif + +LayoutConvertor LayoutTransformFp32(schema::Format src_format, schema::Format dst_format); + +LayoutConvertor LayoutTransformInt8(schema::Format src_format, schema::Format dst_format); + +LayoutConvertor LayoutTransform(TypeId data_type, schema::Format src_format, schema::Format dst_format); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_LAYOUT_TRANSFORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc new file mode 100644 index 00000000000..a979c70d67f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/base/matrix.h" +#include "utils/log_adapter.h" + +namespace mindspore::kernel { +Matrix *TransformMatrixGenerator(int m, int k) { + auto matrix = new Matrix; + auto aa = malloc(m * k * sizeof(float)); + matrix->SetData(aa); + matrix->SetNum(m, k); +// matrix->data_ = malloc(m * k * sizeof(float)); +// matrix->m_ = m; +// matrix->k_ = k; +// matrix->row_major_ = true; + return matrix; +} + +void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt) { + int m = matrix_g->GetM(); + int k = matrix_g->GetK(); + auto matrix_g_data = reinterpret_cast(matrix_g->GetData()); + auto matrix_gt_data = reinterpret_cast(matrix_gt->GetData()); + // m represents input unit, only 4 or 8 can be accepted for input unit. + // k represents kernel unit, varies from 2 to 7. + if (m == 4 && k == 2) { + MatrixG4x2(matrix_g_data); + MatrixGT2x4(matrix_gt_data); + } else if (m == 8 && k == 2) { + MatrixG8x2(matrix_g_data); + MatrixGT2x8(matrix_gt_data); + } else if (m == 8 && k == 3) { + MatrixG8x3(matrix_g_data); + MatrixGT3x8(matrix_gt_data); + } else if (m == 8 && k == 4) { + MatrixG8x4(matrix_g_data); + MatrixGT4x8(matrix_gt_data); + } else if (m == 8 && k == 5) { + MatrixG8x5(matrix_g_data); + MatrixGT5x8(matrix_gt_data); + } else if (m == 8 && k == 6) { + MatrixG8x6(matrix_g_data); + MatrixGT6x8(matrix_gt_data); + } else if (m == 8 && k == 7) { + MatrixG8x7(matrix_g_data); + MatrixGT7x8(matrix_gt_data); + } else { + MS_LOG(ERROR) << "Unsupported input unit or kernel unit."; + return; + } +} + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row) { + // row-major implementation + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matrix.h b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h new file mode 100644 index 00000000000..d3b49393a40 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/matrix.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ + +#include +#include +#include "src/runtime/kernel/arm/opclib/winograd_utils.h" + +namespace mindspore::kernel { +class Matrix { + public: + Matrix() = default; + ~Matrix() { + if (data_ != nullptr) { + free(data_); + } + } + + void SetData(void *data) { this->data_ = data; } + + void *GetData() { return this->data_; } + + void SetNDim(int dim) { this->n_dim_ = dim; } + + int GetNDim() { return this->n_dim_; } + + void SetShape(std::vector shape) { this->shape_ = shape; } + + std::vector GetShape() { return this->shape_; } + + void SetStride(std::vector stride) { this->stride_ = stride; } + + std::vector GetStride() { return this->stride_; } + + void SetNum(int m, int k) { + this->m_ = m; + this->k_ = k; + } + + int GetM() { return this->m_; } + + int GetK() { return this->k_; } + + protected: + void *data_; + std::vector shape_; + std::vector stride_; + int m_; + int k_; + int n_dim_; + bool row_major_; +}; +// struct Matrix { +// void *data_; +// int *shape_; +// int *stride_; +// int m_; +// int k_; +// int n_dim_; +// bool row_major_; +// ~Matrix() { +// if (data_ != nullptr) { +// free(data_); +// } +// if (shape_ != nullptr) { +// free(shape_); +// } +// if (shape_ != nullptr) { +// free(stride_); +// } +// } +//}; + +Matrix *TransformMatrixGenerator(int m, int k); + +// Chinese Remainder Theorem interp: 0.5 +void ChooseMatrixG(Matrix *matrix_g, Matrix *matrix_gt); + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, bool row); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MATRIX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc new file mode 100644 index 00000000000..bdeef575a07 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/pooling_base.h" +#include +#include "src/runtime/kernel/arm/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/fp32/pooling.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore::kernel { +int PoolingBaseCPUKernel::SetQuantParam() { + // per tensor init + pooling_quant_arg_ = reinterpret_cast(malloc(2 * sizeof(QuantArg *))); + pooling_quant_arg_[0] = reinterpret_cast(malloc(sizeof(QuantArg))); + pooling_quant_arg_[1] = reinterpret_cast(malloc(sizeof(QuantArg))); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_arg = input_tensor->GetQuantParams(); + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_arg = out_tensor->GetQuantParams(); + if (in_quant_arg.front().scale != out_quant_arg.front().scale || + in_quant_arg.front().zeroPoint != out_quant_arg.front().zeroPoint) { + MS_LOG(ERROR) << "Scale/ZeroPoint of output must be equal to input's"; + return RET_ERROR; + } + pooling_quant_arg_[0][0].scale_ = in_quant_arg.front().scale; + pooling_quant_arg_[0][0].zp_ = in_quant_arg.front().zeroPoint; + pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale; + pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint; + return RET_OK; +} + +void PoolingBaseCPUKernel::FreeQuantParam() { + if (pooling_quant_arg_ != nullptr) { + for (int i = 0; i < 2; ++i) { + if (*(pooling_quant_arg_ + i) != nullptr) { + free(*(pooling_quant_arg_ + i)); + } + } + } +} + +int PoolingBaseCPUKernel::Init() { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + pooling_param_->thread_num_ = thread_count_; + MS_ASSERT(this->opParameter != nullptr); + auto in_tensor = this->inputs_.front(); + auto out_tensor = this->outputs_.front(); + MS_ASSERT(in_tensor != nullptr); + MS_ASSERT(out_tensor != nullptr); + pooling_param_->input_batch_ = in_tensor->Batch(); + pooling_param_->input_channel_ = in_tensor->Channel(); + pooling_param_->input_h_ = in_tensor->Height(); + pooling_param_->input_w_ = in_tensor->Width(); + pooling_param_->output_batch_ = out_tensor->Batch(); + pooling_param_->output_channel_ = out_tensor->Channel(); + pooling_param_->output_h_ = out_tensor->Height(); + pooling_param_->output_w_ = out_tensor->Width(); + return RET_OK; +} + +kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); + auto *kernel = new (std::nothrow) PoolingInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PoolingInt8CPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuPoolingFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Pooling); + auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PoolingCPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuPoolingKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Pooing); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + case kNumberTypeUInt8: + kernel = CpuPoolingInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; + case kNumberTypeFloat32: + kernel = CpuPoolingFp32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Pooling, CpuPoolingKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h new file mode 100644 index 00000000000..046dd15ba90 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" +#include "include/errorcode.h" + +using mindspore::lite::Context; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +namespace mindspore::kernel { +class PoolingBaseCPUKernel : public LiteKernel { + public: + PoolingBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + pooling_param_ = reinterpret_cast(opParameter); + } + ~PoolingBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return RET_OK; } + int Run() override { return RET_OK; } + int SetQuantParam(); + void FreeQuantParam(); + + protected: + int thread_count_; + const Context *ctx_; + PoolingParameter *pooling_param_; + QuantArg **pooling_quant_arg_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_POOLING_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc new file mode 100644 index 00000000000..f3c7d79a425 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/base/reshape_base.h" +#include +#include "src/runtime/kernel/arm/int8/reshape_int8.h" +#include "src/runtime/kernel/arm/fp32/reshape.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "include/context.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reshape; + +namespace mindspore::kernel { +int ReshapeBaseCPUKernel::Init() { + reshape_param_->thread_count_ = thread_count_; + return RET_OK; +} + +kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuReshapeFp32OrInt32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const Context *ctx) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto *kernel = new(std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; + return nullptr; + } + return kernel; +} + +kernel::LiteKernel *CpuReshapeKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Concat); + auto input_tensor = inputs.at(kInputIndex); + auto data_type = input_tensor->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeInt8: + case kNumberTypeUInt8: + kernel = CpuReshapeInt8KernelCreator(inputs, outputs, opParameter, ctx); + break; + case kNumberTypeInt32: + case kNumberTypeFloat32: + kernel = CpuReshapeFp32OrInt32KernelCreator(inputs, outputs, opParameter, ctx); + break; + default: + break; + } + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Reshape, CpuReshapeKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h new file mode 100644 index 00000000000..48ae7ee3004 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/reshape_parameter.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeBaseCPUKernel : public LiteKernel { + public: + ReshapeBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + reshape_param_ = reinterpret_cast(opParameter); + } + ~ReshapeBaseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override { return 0; } + + protected: + int thread_count_; + const Context *ctx_; + ReshapeParameter *reshape_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc new file mode 100644 index 00000000000..600337cfb1b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -0,0 +1,252 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC8 = UP_DIV(output_channel, C8NUM); + + size_t tmp_size = oC8 * C8NUM * iC4 * C4NUM * kernel_plane * sizeof(float16_t); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + memset(tmp_addr, 0, tmp_size); + + PackWeightToC4Fp16(origin_weight, tmp_addr, conv_param); + Conv3x3Fp16FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane); + + free(tmp_addr); +} + +int Convolution3x3FP16CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + int output_channel = conv_param_->output_channel_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC8 = UP_DIV(output_channel, C8NUM); + // init weight + size_t transformed_size = iC4 * C8NUM * oC8 * C8NUM * 36 * sizeof(float16_t); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = malloc(fp16_weight_size); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_weight_ failed."; + return RET_ERROR; + } + memset(fp16_weight_, 0, fp16_weight_size); + for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { + fp16_weight_[i] = (float16_t)origin_weight[i]; + } + ProcessFilterFp16(fp16_weight_, transformed_filter_addr_, conv_param_); + + // init bias + size_t new_bias_size = oC8 * C8NUM * sizeof(float16_t); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < out_channel; ++i) { + bias_data_[i] = (float16_t)ori_bias_addr[i]; + } + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::InitTmpBuffer() { + int tile_num = 16; + int k_plane = 36; + int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM); + size_t tile_buffer_size = thread_count_ * tile_num * k_plane * iC4 * C4NUM * sizeof(float16_t); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile_buffer_ failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float16_t); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float16_t); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * + tile_num * sizeof(float16_t); + tmp_out_ = reinterpret_cast(malloc(tmp_out_size)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + memset(tmp_out_, 0, tmp_out_size); + + size_t fp16_input_size = + in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = malloc(fp16_input_size); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_input_ failed."; + return RET_ERROR; + } + memset(fp16_input_, 0, fp16_input_size); + + + // init nhwc4 input + size_t nhwc4_input_size = + iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +void Convolution3x3FP16CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +} + +int Convolution3x3FP16CPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::ReSize() { + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Conv3x3Fp16(reinterpret_cast(nhwc4_input_), transformed_filter_addr_, + reinterpret_cast(bias_data_), output_addr, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, + tmp_out_, task_id, conv_param_); + return RET_OK; +} + +int Convolution3x3Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Fp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3FP16CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = reinterpret_cast(input_tensor->Data()); + // cast fp32 input data to fp16 + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + fp16_input_[i] = (float16_t)ori_input_data[i]; + } + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(reinterpret_cast(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Fp16Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 fp16 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h new file mode 100644 index 00000000000..5cf60d1c1e1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" + +namespace mindspore::kernel { +class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3FP16CPUKernel() override { + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + } + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float16_t *fp16_input_; + float16_t *fp16_weight_; + float16_t *transformed_filter_addr_; + float16_t *tile_buffer_; + float16_t *block_unit_buffer_; + float16_t *tmp_dst_buffer_; + float16_t *tmp_out_; +}; +void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc new file mode 100644 index 00000000000..f00a5835e53 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -0,0 +1,221 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/convolution_fp16.h" +#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int ConvolutionFP16CPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int oc8 = UP_DIV(out_channel, C8NUM); + int channel_block = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc8 * channel_block * C8NUM * C4NUM * kernel_plane; + + // init weight + float *origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + size_t fp16_weight_size = in_channel * out_channel * kernel_h * kernel_w * sizeof(float16_t); + fp16_weight_ = malloc(fp16_weight_size); + if (fp16_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_weight_ failed."; + return RET_ERROR; + } + memset(fp16_weight_, 0, fp16_weight_size); + for (int i = 0; i < fp16_weight_size / sizeof(float16_t); ++i) { + fp16_weight_[i] = (float16_t)origin_weight[i]; + } + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); + PackWeightFp16(fp16_weight_, conv_param_, packed_weight_); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc8 * C8NUM * sizeof(float16_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc8 * C8NUM * sizeof(float16_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + for (int i = 0; i < out_channel; ++i) { + bias_data_[i] = (float16_t)ori_bias[i]; + } + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::InitTmpBuffer() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_batch = conv_param_->input_batch_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int channel_block = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + + // malloc packed_inputs + int cal_num = 16; + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, cal_num); + int unit_size = kernel_plane * channel_block * C4NUM; + int packed_input_size = output_tile_count * cal_num * unit_size; + packed_input_ = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float16_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float16_t)); + + size_t fp16_input_size = + in_channel * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); + fp16_input_ = malloc(fp16_input_size); + if (fp16_input_ == nullptr) { + MS_LOG(ERROR) << "malloc fp16_input_ failed."; + return RET_ERROR; + } + memset(fp16_input_, 0, fp16_input_size); + + size_t nhwc4_input_size = channel_block * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * + conv_param_->input_w_ * sizeof(float16_t); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + tmp_output_block_ = reinterpret_cast(malloc(cal_num * out_channel * sizeof(float16_t))); + if (tmp_output_block_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +void ConvolutionFP16CPUKernel::ConfigInputOutput() { + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); +} + +int ConvolutionFP16CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int ConvolutionFP16CPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + ConvFp16(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_output_block_, output_addr, task_id, conv_param_); + return RET_OK; +} + +int ConvolutionFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionFP16CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = reinterpret_cast(input_tensor->Data()); + // cast fp32 input data to fp16 + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + fp16_input_[i] = (float16_t)ori_input_data[i]; + } + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(reinterpret_cast(fp16_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionFp16Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h new file mode 100644 index 00000000000..bf7540f6fcb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ + +#include +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" + +namespace mindspore::kernel { +typedef void (*FP16_GEMM_FUNC)(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); + +class ConvolutionFP16CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionFP16CPUKernel() override { + if (fp16_input_ != nullptr) { + free(fp16_input_); + } + if (fp16_weight_ != nullptr) { + free(fp16_weight_); + } + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + bool support_fp16_ = true; + float16_t *fp16_input_; + float16_t *fp16_weight_; + float16_t *packed_input_; + float16_t *packed_weight_; + float16_t *tmp_output_block_; + FP16_GEMM_FUNC gemm_func_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_FP16_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc new file mode 100644 index 00000000000..b12b5ce5656 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/activation.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_HSWISH; +using mindspore::schema::ActivationType_LEAKY_RELU; +using mindspore::schema::ActivationType_RELU; +using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::PrimitiveType_Activation; + +namespace mindspore::kernel { +int ActivationCPUKernel::Init() { return RET_OK; } + +int ActivationCPUKernel::ReSize() { return RET_OK; } + +int ActivationCPUKernel::DoActivation(int task_id) { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto length = inputs_.at(0)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + auto error_code = RET_OK; + + if (type_ == schema::ActivationType_RELU) { + error_code = Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_RELU6) { + error_code = Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_LEAKY_RELU) { + error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); + } else if (type_ == schema::ActivationType_SIGMOID) { + error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_TANH) { + error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_HSWISH) { + error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else { + MS_LOG(ERROR) << "Activation type error"; + return RET_ERROR; + } + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto activation_kernel = reinterpret_cast(cdata); + auto error_code = activation_kernel->DoActivation(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ActivationRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ActivationCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Activation); + auto *kernel = new (std::nothrow) ActivationCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Activation, CpuActivationFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h new file mode 100644 index 00000000000..db85deec1f8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/activation.h" + +namespace mindspore::kernel { +class ActivationCPUKernel : public LiteKernel { + public: + ActivationCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(param, inputs, outputs), thread_count_(ctx->threadNum) { + type_ = (reinterpret_cast(param))->type_; + alpha_ = (reinterpret_cast(param))->alpha_; + } + ~ActivationCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoActivation(int task_id); + + private: + int thread_count_; + int type_; + float alpha_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ACTIVATION_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc new file mode 100644 index 00000000000..a17828affd6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/addn.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_AddN; + +namespace mindspore::kernel { +namespace { +constexpr int kLeastInputNum = 2; +} + +int AddNCPUKernel::Init() { return RET_OK; } + +int AddNCPUKernel::ReSize() { return RET_OK; } + +int AddNCPUKernel::Run() { + auto input0_data = reinterpret_cast(inputs_[0]->Data()); + auto input1_data = reinterpret_cast(inputs_[1]->Data()); + auto output_data = reinterpret_cast(outputs_[0]->Data()); + auto element_num = inputs_[0]->ElementsNum(); + + ElementAdd(input0_data, input1_data, output_data, element_num); + for (int i = 2; i < inputs_.size(); ++i) { + ElementAdd(reinterpret_cast(inputs_[i]->Data()), output_data, output_data, element_num); + } + return RET_OK; +} + +kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) AddNCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new AddNCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed! name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_AddN, CpuAddNFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h new file mode 100644 index 00000000000..544ab928af8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ + +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" + + +namespace mindspore::kernel { +class AddNCPUKernel : public LiteKernel { + public: + AddNCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~AddNCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDN_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc new file mode 100644 index 00000000000..c15a86105c9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/argminmax.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ArgMax; +using mindspore::schema::PrimitiveType_ArgMin; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int ArgMinMaxCPUKernel::Init() { + switch (opParameter->type_) { + case PrimitiveType_ArgMax: + get_max_ = true; + break; + case PrimitiveType_ArgMin: + get_max_ = false; + break; + default: + MS_LOG(ERROR) << "Unexpected type " << opParameter->type_; + return RET_ERROR; + } + auto dims_size = inputs_.at(0)->shape().size(); + axis_ = reinterpret_cast(opParameter)->axis_; + axis_ = axis_ < 0 ? axis_ + dims_size : axis_; + return RET_OK; +} + +int ArgMinMaxCPUKernel::Run() { + auto input = inputs_.at(0); + + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + + auto shape = input->shape().data(); + int dims_number = input->shape().size(); + bool out_value = reinterpret_cast(opParameter)->out_value_; + if (get_max_) { + ArgMax(input_data, shape, dims_number, axis_, out_value, output_data); + } else { + ArgMin(input_data, shape, dims_number, axis_, out_value, output_data); + } + return RET_OK; +} + +kernel::LiteKernel *CpuArgMinMaxFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ArgMinMaxCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ArgMinMaxCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_ArgMax, CpuArgMinMaxFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_ArgMin, CpuArgMinMaxFp32KernelCreator) +} // namespace mindspore::kernel + + + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h new file mode 100644 index 00000000000..683f11f7bed --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class ArgMinMaxCPUKernel : public LiteKernel { + public: + ArgMinMaxCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + + ~ArgMinMaxCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + int axis_; + bool get_max_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc new file mode 100644 index 00000000000..8129b5f58c3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/int8/add_int8.h" +#include "src/runtime/kernel/arm/int8/mul_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Eltwise; + +namespace mindspore::kernel { + +ArithmeticCPUKernel::~ArithmeticCPUKernel() { + if (tile_data0_ != nullptr) { + free(tile_data0_); + tile_data0_ = nullptr; + } + if (tile_data1_ != nullptr) { + free(tile_data1_); + tile_data1_ = nullptr; + } +} +int ArithmeticCPUKernel::Init() { + auto element_num = outputs_[0]->ElementsNum(); + + tile_data0_ = new float[element_num]; + tile_data1_ = new float[element_num]; + + return RET_OK; +} + +int ArithmeticCPUKernel::ReSize() { return RET_OK; } + +int ArithmeticCPUKernel::DoArithmetic(int task_id) { + auto input0_data = reinterpret_cast(inputs_[0]->Data()); + auto input1_data1 = reinterpret_cast(inputs_[1]->Data()); + auto output_data = reinterpret_cast(outputs_[0]->Data()); + auto element_num = outputs_[0]->ElementsNum(); + if (arithmeticParameter_->broadcasting_) { + if (arithmetic_broadcast_run_ == nullptr) { + MS_LOG(ERROR) << "broadcasting_run function is nullptr!"; + return RET_ERROR; + } + + MS_ASSERT(thread_count_ != 0); + int stride = UP_DIV(element_num, thread_count_); + int count = MSMIN(stride, element_num - stride * task_id); + + int error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, + output_data + stride * task_id, count); + + if (error_code != RET_OK) { + return RET_ERROR; + } + } else if (arithmetic_run_ != nullptr) { + int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num); + if (error_code != RET_OK) { + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto arithmetic_kernel = reinterpret_cast(cdata); + auto error_code = arithmetic_kernel->DoArithmetic(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ArithmeticsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticCPUKernel::Run() { + if (arithmeticParameter_->broadcasting_) { + auto input_data0 = reinterpret_cast(inputs_[0]->Data()); + auto input_data1 = reinterpret_cast(inputs_[1]->Data()); + TileDimensions(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_); + } + int error_code = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(parameter); + MS_ASSERT(inputs.at(0)); + auto data_type = inputs.at(0)->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeFloat32: + kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx); + break; + case kNumberTypeInt8: + if (desc.type == schema::PrimitiveType_Add) { + kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx); + } else if (desc.type == schema::PrimitiveType_Mul) { + kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx); + } else { + } + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Equal, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_NotEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Less, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_LessEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Greater, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_GreaterEqual, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator) + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h new file mode 100644 index 00000000000..981e00e3dcb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" +#include "schema/model_generated.h" + +using mindspore::schema::PrimitiveType_Add; +using mindspore::schema::PrimitiveType_Div; +using mindspore::schema::PrimitiveType_Equal; +using mindspore::schema::PrimitiveType_FloorDiv; +using mindspore::schema::PrimitiveType_FloorMod; +using mindspore::schema::PrimitiveType_Greater; +using mindspore::schema::PrimitiveType_GreaterEqual; +using mindspore::schema::PrimitiveType_Less; +using mindspore::schema::PrimitiveType_LessEqual; +using mindspore::schema::PrimitiveType_LogicalAnd; +using mindspore::schema::PrimitiveType_LogicalOr; +using mindspore::schema::PrimitiveType_Maximum; +using mindspore::schema::PrimitiveType_Minimum; +using mindspore::schema::PrimitiveType_Mul; +using mindspore::schema::PrimitiveType_NotEqual; +using mindspore::schema::PrimitiveType_SquaredDifference; +using mindspore::schema::PrimitiveType_Sub; + +namespace mindspore::kernel { +class ArithmeticCPUKernel : public LiteKernel { + typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size); + typedef int (*ArithmeticBroadcastRun)(float *input0, float *input1, float *tile_input0, float *tile_input1, + float *output, int element_size, ArithmeticParameter *param); + + public: + ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->threadNum) { + switch (parameter->type_) { + case PrimitiveType_Mul: + arithmetic_run_ = ElementMul; + arithmetic_broadcast_run_ = BroadcastMul; + break; + case PrimitiveType_Add: + arithmetic_run_ = ElementAdd; + arithmetic_broadcast_run_ = BroadcastAdd; + break; + case PrimitiveType_Sub: + arithmetic_run_ = ElementSub; + arithmetic_broadcast_run_ = BroadcastSub; + break; + case PrimitiveType_Div: + arithmetic_run_ = ElementDiv; + arithmetic_broadcast_run_ = BroadcastDiv; + break; + case PrimitiveType_LogicalAnd: + arithmetic_run_ = ElementLogicalAnd; + arithmetic_broadcast_run_ = BroadcastLogicalAnd; + break; + case PrimitiveType_LogicalOr: + arithmetic_run_ = ElementLogicalOr; + arithmetic_broadcast_run_ = BroadcastLogicalOr; + break; + case PrimitiveType_Maximum: + arithmetic_run_ = ElementMaximum; + arithmetic_broadcast_run_ = BroadcastMaximum; + break; + case PrimitiveType_Minimum: + arithmetic_run_ = ElementMinimum; + arithmetic_broadcast_run_ = BroadcastMinimum; + break; + case PrimitiveType_FloorDiv: + arithmetic_run_ = ElementFloorDiv; + arithmetic_broadcast_run_ = BroadcastFloorDiv; + break; + case PrimitiveType_FloorMod: + arithmetic_run_ = ElementFloorMod; + arithmetic_broadcast_run_ = BroadcastFloorMod; + case PrimitiveType_Equal: + arithmetic_run_ = ElementEqual; + arithmetic_broadcast_run_ = BroadcastEqual; + break; + case PrimitiveType_NotEqual: + arithmetic_run_ = ElementNotEqual; + arithmetic_broadcast_run_ = BroadcastNotEqual; + break; + case PrimitiveType_Less: + arithmetic_run_ = ElementEqual; + arithmetic_broadcast_run_ = BroadcastEqual; + break; + case PrimitiveType_LessEqual: + arithmetic_run_ = ElementNotEqual; + arithmetic_broadcast_run_ = BroadcastNotEqual; + break; + case PrimitiveType_Greater: + arithmetic_run_ = ElementGreater; + arithmetic_broadcast_run_ = BroadcastGreater; + break; + case PrimitiveType_GreaterEqual: + arithmetic_run_ = ElementGreaterEqual; + arithmetic_broadcast_run_ = BroadcastGreaterEqual; + break; + case PrimitiveType_SquaredDifference: + arithmetic_run_ = ElementSquaredDifference; + arithmetic_broadcast_run_ = BroadcastSquaredDifference; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << parameter->type_; + arithmetic_run_ = nullptr; + arithmetic_broadcast_run_ = nullptr; + break; + } + arithmeticParameter_ = reinterpret_cast(parameter); + } + ~ArithmeticCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmetic(int task_id); + + private: + int thread_count_; + float *tile_data0_ = nullptr; + float *tile_data1_ = nullptr; + ArithmeticParameter *arithmeticParameter_; + ArithmeticRun arithmetic_run_; + ArithmeticBroadcastRun arithmetic_broadcast_run_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc new file mode 100644 index 00000000000..a6bf3bd60a1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/arithmetic_self.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ArithmeticSelfCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int ArithmeticSelfCPUKernel::ReSize() { + data_size_ = inputs_[0]->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int ArithmeticSelfRuns(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoArithmeticSelf(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) { + int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + if (arithmeticSelf_run_) { + auto ret = arithmeticSelf_run_(in_ptr_ + offset, out_ptr_ + offset, size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Run failed, illegal input! "; + return ret; + } + } else { + MS_LOG(ERROR) << "Run function is null! "; + return RET_ERROR; + } + return RET_OK; +} + +int ArithmeticSelfCPUKernel::Run() { + auto input_tensor = inputs_.at(0); + auto out_tensor = outputs_.at(0); + in_ptr_ = reinterpret_cast(input_tensor->Data()); + out_ptr_ = reinterpret_cast(out_tensor->Data()); + int ret = LiteBackendParallelLaunch(ArithmeticSelfRuns, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuArithmeticSelfFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(opParameter, inputs, outputs, ctx); + MS_ASSERT(kernel != nullptr); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Abs, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Cos, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Exp, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Log, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Square, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Sqrt, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h new file mode 100644 index 00000000000..c6be54dd678 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h" +#include "schema/model_generated.h" +#include "include/context.h" + + +using mindspore::lite::Context; +using mindspore::schema::PrimitiveType_Abs; +using mindspore::schema::PrimitiveType_Cos; +using mindspore::schema::PrimitiveType_Exp; +using mindspore::schema::PrimitiveType_Floor; +using mindspore::schema::PrimitiveType_Log; +using mindspore::schema::PrimitiveType_LogicalNot; +using mindspore::schema::PrimitiveType_Rsqrt; +using mindspore::schema::PrimitiveType_Sin; +using mindspore::schema::PrimitiveType_Sqrt; +using mindspore::schema::PrimitiveType_Square; +using mindspore::schema::PrimitiveType_Ceil; + +namespace mindspore::kernel { +class ArithmeticSelfCPUKernel : public LiteKernel { + typedef int (*ArithmeticSelfRun)(float *input, float *output, int element_size); + + public: + explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + switch (parameter->type_) { + case PrimitiveType_Abs: + arithmeticSelf_run_ = ElementAbs; + break; + case PrimitiveType_Cos: + arithmeticSelf_run_ = ElementCos; + break; + case PrimitiveType_Exp: + arithmeticSelf_run_ = ElementExp; + break; + case PrimitiveType_Log: + arithmeticSelf_run_ = ElementLog; + break; + case PrimitiveType_Square: + arithmeticSelf_run_ = ElementSquare; + break; + case PrimitiveType_Sqrt: + arithmeticSelf_run_ = ElementSqrt; + break; + case PrimitiveType_Rsqrt: + arithmeticSelf_run_ = ElementRsqrt; + break; + case PrimitiveType_Sin: + arithmeticSelf_run_ = ElementSin; + break; + case PrimitiveType_LogicalNot: + arithmeticSelf_run_ = ElementLogicalNot; + break; + case PrimitiveType_Floor: + arithmeticSelf_run_ = ElementFloor; + break; + case PrimitiveType_Ceil: + arithmeticSelf_run_ = ElementCeil; + break; + default: + break; + } + arithmeticSelfParameter_ = reinterpret_cast(parameter); + } + ~ArithmeticSelfCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoArithmeticSelf(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + size_t data_size_; + ArithmeticSelfParameter *arithmeticSelfParameter_; + ArithmeticSelfRun arithmeticSelf_run_; + const Context *ctx_; + float *in_ptr_; + float *out_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc new file mode 100644 index 00000000000..b2b6d9fd344 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/batch_to_space.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BatchToSpace; + +namespace mindspore::kernel { + +int BatchToSpaceCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); + for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { + if (param->crops_[i] != 0) { + no_crop_ = false; + } + } + return RET_OK; +} + +int BatchToSpaceCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + auto out_shape = output->shape(); + BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); + + if (no_crop_) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_); + } + + return RET_OK; +} + +kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator) +} // namespace mindspore::kernel + + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h new file mode 100644 index 00000000000..3667d47f3e3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class BatchToSpaceCPUKernel : public LiteKernel { + public: + BatchToSpaceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), no_crop_(true) {} + + ~BatchToSpaceCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + bool no_crop_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc new file mode 100644 index 00000000000..c698cf9ba34 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/bias.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/int8/bias_add_int8.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { +int BiasCPUKernel::ReSize() { return RET_OK; } + +int BiasCPUKernel::Run() { + auto in = reinterpret_cast(inputs_.at(0)->Data()); + auto bias = reinterpret_cast(inputs_.at(1)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + size_t data_size = inputs_.at(0)->ElementsNum(); + auto tile_in = new float[data_size]; + auto tile_bias = new float[data_size]; + BroadcastAdd(in, bias, tile_in, tile_bias, out, data_size, bias_param_); + delete[] tile_in; + delete[] tile_bias; + return RET_OK; +} + +int BiasCPUKernel::Init() { + auto dims = inputs_[0]->shape(); + MS_ASSERT(dims.size() <= 5); + bias_param_->ndim_ = dims.size(); + for (int i = 0; i < bias_param_->ndim_; i++) { + bias_param_->in_shape0_[i] = dims[i]; + bias_param_->in_shape1_[i] = 1; + bias_param_->out_shape_[i] = dims[i]; + } + bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1]; + return RET_OK; +} + +kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_BiasAdd); + MS_ASSERT(inputs.at(0)); + auto data_type = inputs.at(0)->data_type(); + kernel::LiteKernel *kernel = nullptr; + switch (data_type) { + case kNumberTypeFloat32: + kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs); + break; + case kNumberTypeInt8: + kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx); + break; + default: + break; + } + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_BiasAdd, CpuBiasFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h new file mode 100644 index 00000000000..70808d73904 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" + +namespace mindspore::kernel { +class BiasCPUKernel : public LiteKernel { + public: + BiasCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + bias_param_ = reinterpret_cast(parameter); + } + ~BiasCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + ArithmeticParameter *bias_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc new file mode 100644 index 00000000000..072b91a5181 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/broadcast_to.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BroadcastTo; + +namespace mindspore::kernel { + +int BroadcastToCPUKernel::Init() { + auto input_shape = inputs_[0]->shape(); + for (size_t i = 0; i < input_shape.size(); ++i) { + shape_info_.input_shape_[i] = input_shape[i]; + } + + shape_info_.input_shape_size_ = static_cast(input_shape.size()); + auto output_shape = outputs_[0]->shape(); + for (size_t i = 0; i < output_shape.size(); ++i) { + shape_info_.output_shape_[i] = output_shape[i]; + } + shape_info_.output_shape_size_ = static_cast(output_shape.size()); + return RET_OK; +} + +int BroadcastToCPUKernel::Run() { + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + + return BroadcastTo(input_data, &shape_info_, output_data); +} + +kernel::LiteKernel *CpuBroadcastToFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) BroadcastToCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new BroadcastToCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_BroadcastTo, CpuBroadcastToFp32KernelCreator) +} // namespace mindspore::kernel + + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h new file mode 100644 index 00000000000..8952d4deb6b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/fp32/broadcast_to.h" + +namespace mindspore::kernel { +class BroadcastToCPUKernel : public LiteKernel { + public: + BroadcastToCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~BroadcastToCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; + private: + BroadcastShapeInfo shape_info_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc new file mode 100644 index 00000000000..7c22b139015 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/cast.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/cast.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Cast; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +const std::vector kSupportInputDataType = {kNumberTypeUInt8, kNumberTypeInt32}; +int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { + if (cdata == nullptr) { + MS_LOG(ERROR) << "input cdata is nullptr!"; + return RET_ERROR; + } + + return reinterpret_cast(cdata)->DoCast(thread_id); +} +} // namespace + +int CastCPUKernel::Init() { + data_num_ = inputs_[0]->ElementsNum(); + if (data_num_ == 0) { + return RET_OK; + } + thread_num_ = MSMIN(thread_num_, data_num_); + stride_ = UP_DIV(data_num_, thread_num_); + return RET_OK; +} + +int CastCPUKernel::DoCast(int thread_id) { + auto input = inputs_.at(0); + int data_num = MSMIN(stride_, data_num_ - thread_id * stride_); + if (data_num <= 0) { + return RET_OK; + } + + auto offset = thread_id * stride_; + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + switch (input->data_type()) { + case kNumberTypeUInt8: + Uint8ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + break; + default: + MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + return RET_ERROR; + } + return RET_OK; +} + +int CastCPUKernel::Run() { + if (data_num_ == 0) { + return RET_OK; + } + return LiteBackendParallelLaunch(CastRun, this, thread_num_); +} + +kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "Input context is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) CastCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CastCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Cast, CpuCastFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h new file mode 100644 index 00000000000..02f9565816e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class CastCPUKernel : public LiteKernel { + public: + CastCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + if (ctx != nullptr) { + thread_num_ = ctx->threadNum; + } + } + + ~CastCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + }; + int Run() override; + int DoCast(int thread_id); + private: + uint32_t thread_num_; + uint32_t stride_; + uint32_t data_num_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc new file mode 100644 index 00000000000..bc820ac2fbe --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/concat.h" +#include +#include "src/runtime/kernel/arm/opclib/fp32/concat.h" +#include "src/kernel_registry.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { + int ConcatCPUKernel::Init() { + ConcatBaseCPUKernel::Init(); + schema::Format input0_format = inputs_[0]->GetFormat(); + bool need_convert_format = false; + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != input0_format) { + need_convert_format = true; + } + } + if (!need_convert_format) { + outputs_[0]->SetFormat(input0_format); + return RET_OK; + } + MS_LOG(ERROR) << "All input format should be the same!"; + return RET_ERROR; + } + + int ConcatCPUKernel::ReSize() { return RET_OK; } + + int ConcatCPUKernel::Run() { + auto input_num = inputs_.size(); + std::vector inputs_addr(input_num, nullptr); + std::vector inputs_output_shape(input_num + 1, nullptr); + + std::vector > shapes; + for (size_t i = 0; i < input_num; ++i) { + inputs_addr[i] = inputs_[i]->Data(); + shapes.push_back(inputs_[i]->shape()); + inputs_output_shape[i] = shapes[i].data(); + } + auto output_shape = outputs_.at(0)->shape(); + inputs_output_shape[input_num] = output_shape.data(); + auto output_addr = outputs_.at(0)->Data(); + + Concat(reinterpret_cast(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(), + output_shape.size(), output_addr); + return RET_OK; + } +} // namespace mindspore::kernel + + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h new file mode 100644 index 00000000000..078921a53dd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/concat_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatCPUKernel : public ConcatBaseCPUKernel { + public: + ConcatCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {} + + ~ConcatCPUKernel() = default; + + int Init() override; + + int ReSize() override; + + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc new file mode 100644 index 00000000000..fb6f68a58ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include "schema/model_generated.h" +#include "src/kernel_factory.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +int ConvolutionCPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int out_channel = conv_param_->output_channel_; + int oc8 = UP_DIV(out_channel, C8NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed weight failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackWeightFp32(origin_weight, conv_param_, packed_weight_); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc8 * C8NUM * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc8 * C8NUM * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionCPUKernel::InitTmpBuffer() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_batch = conv_param_->input_batch_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int kernel_plane = kernel_h * kernel_w; + + // malloc packed_inputs + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * TILE_NUM * unit_size; + packed_input_ = reinterpret_cast(malloc(in_batch * packed_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed input failed."; + return RET_ERROR; + } + memset(packed_input_, 0, in_batch * packed_input_size * sizeof(float)); + + size_t nhwc4_input_size = + ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + // tmp out + tmp_output_block_ = reinterpret_cast(malloc(TILE_NUM * out_channel * sizeof(float))); + if (tmp_output_block_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp output block failed."; + return RET_ERROR; + } + return RET_OK; +} + +void ConvolutionCPUKernel::ConfigInputOutput() { + // set output format + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + // select trans func for input + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +} + +int ConvolutionCPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + return RET_OK; +} + +int ConvolutionCPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionCPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + ConvFp32(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_output_block_, output_addr, task_id, conv_param_); + return RET_OK; +} + +int ConvolutionImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionCPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h new file mode 100644 index 00000000000..688b981b22e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" + +namespace mindspore::kernel { +class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionCPUKernel() override { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (tmp_output_block_ != nullptr) { + free(tmp_output_block_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float *packed_input_; + float *packed_weight_; + float *tmp_output_block_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc new file mode 100644 index 00000000000..73195f2bd63 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc @@ -0,0 +1,231 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_1x1.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { + if (c4_output_ != nullptr) { + free(c4_output_); + c4_output_ = nullptr; + } + if (c4_input_ != nullptr) { + free(c4_input_); + c4_input_ = nullptr; + } + if (pre_trans_input_) { + free(input_ptr_); + input_ptr_ = nullptr; + } + if (tmp_ptr_ != nullptr) { + free(tmp_ptr_); + tmp_ptr_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + delete matmul_param_; +} + +int Convolution1x1CPUKernel::ReSize() { return RET_OK; } + +void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { + matmul_param_ = new StrassenMatMulParameter(); + matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; + matmul_param_->col_ = UP_DIV(conv_param_->output_channel_, FP32_STRASSEN_UINT); + matmul_param_->deep_ = UP_DIV(conv_param_->input_channel_, FP32_STRASSEN_UINT); + matmul_param_->a_stride_ = matmul_param_->row_ * FP32_STRASSEN_UINT; + matmul_param_->b_stride_ = matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT; + matmul_param_->c_stride_ = matmul_param_->row_ * FP32_STRASSEN_UINT; +} + +int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { + if (inputs_.size() == 3) { + bias_ptr_ = reinterpret_cast(malloc(matmul_param_->col_ * C4NUM * sizeof(float))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; + return RET_ERROR; + } + memset(bias_ptr_, 0, matmul_param_->col_ * C4NUM * sizeof(float)); + memcpy(bias_ptr_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); + } else { + bias_ptr_ = nullptr; + } + + weight_ptr_ = reinterpret_cast( + malloc(matmul_param_->col_ * matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT * sizeof(float))); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, matmul_param_->col_ * matmul_param_->deep_ * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)); + Pack1x1WeightFp32(reinterpret_cast(inputs_[1]->Data()), weight_ptr_, conv_param_); + return RET_OK; +} + +int Convolution1x1CPUKernel::InitConv1x1Param() { + pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 || + conv_param_->stride_w_ != 1); + if (pre_trans_input_) { + input_ptr_ = reinterpret_cast(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float))); + if (input_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!"; + return RET_MEMORY_FAILED; + } + memset(input_ptr_, 0, matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float)); + } + + thread_hw_count_ = MSMIN(opParameter->thread_num_, matmul_param_->row_); + thread_hw_stride_ = UP_DIV(matmul_param_->row_, thread_hw_count_); + + thread_oc4_count_ = MSMIN(opParameter->thread_num_, matmul_param_->col_); + thread_oc_stride_ = UP_DIV(matmul_param_->col_, thread_oc4_count_) * C4NUM; + + tmp_ptr_ = reinterpret_cast(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * sizeof(float))); + if (tmp_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc tmp_ptr_ error!"; + return RET_MEMORY_FAILED; + } + c4_output_ = reinterpret_cast(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * + sizeof(float))); + if (c4_output_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc c4_output_ error!"; + return RET_MEMORY_FAILED; + } + + c4_input_ = reinterpret_cast(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * + sizeof(float))); + if (c4_input_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc c4_input_ error!"; + return RET_MEMORY_FAILED; + } + return RET_OK; +} + +void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) { + output_ptr_ = src_output; + PackNHWCToNC4HW4Fp32(src_input, c4_input_, 1, conv_param_->input_h_ * conv_param_->input_w_, + conv_param_->input_channel_); + + if (!pre_trans_input_) { + input_ptr_ = c4_input_; + return; + } + + Conv1x1InputPackFp32(c4_input_, input_ptr_, conv_param_); + return; +} + +int Convolution1x1CPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + InitConv1x1MatmulParam(); + + int error_code = InitConv1x1BiasWeight(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution base init failed."; + return error_code; + } + error_code = InitConv1x1Param(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution base init failed."; + return error_code; + } + return RET_OK; +} + +int Convolution1x1CPUKernel::DoStrassen(int task_id) { + matmul_param_->row_ = MSMIN(thread_hw_stride_, matmul_param_->row_ - task_id * thread_hw_stride_); + if (matmul_param_->row_ <= 0) { + return RET_OK; + } + + auto error_code = Conv1x1Fp32(input_ptr_ + task_id * thread_hw_stride_ * C4NUM, weight_ptr_, + c4_output_ + task_id * thread_hw_stride_ * C4NUM, + tmp_ptr_ + task_id * thread_hw_stride_ * matmul_param_->deep_ * C4NUM, *matmul_param_); + if (error_code != 0) { + MS_LOG(ERROR) << "DoStrassen error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; + return RET_OK; +} + +int Convolution1x1StrassenRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv1x1 = reinterpret_cast(cdata); + auto error_code = conv1x1->DoStrassen(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution1x1StrassenRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution1x1CPUKernel::DoPostFunc(int task_id) { + int cur_oc = MSMIN(thread_oc_stride_, conv_param_->output_channel_ - task_id * thread_oc_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + PostConvFuncFp32(c4_output_ + matmul_param_->row_ * thread_oc_stride_ * task_id, + output_ptr_ + task_id * thread_oc_stride_, bias_ptr_ + task_id * thread_oc_stride_, cur_oc, + matmul_param_->row_, conv_param_->output_channel_, conv_param_->is_relu_, conv_param_->is_relu6_); + return RET_OK; +} + +int Convolution1x1PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv1x1 = reinterpret_cast(cdata); + auto error_code = conv1x1->DoPostFunc(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution1x1PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution1x1CPUKernel::Run() { + auto src_in = reinterpret_cast(inputs_[0]->Data()); + auto src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + Pre1x1Trans(src_in + batch_index * matmul_param_->deep_ * matmul_param_->a_stride_, + src_out + batch_index * matmul_param_->col_ * matmul_param_->c_stride_); + + int error_code = LiteBackendParallelLaunch(Convolution1x1StrassenRun, this, thread_hw_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]"; + return RET_ERROR; + } + + error_code = LiteBackendParallelLaunch(Convolution1x1PostFuncRun, this, thread_oc4_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 post function error error_code[" << error_code << "]"; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h new file mode 100644 index 00000000000..9d9c951c289 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ + +#include +#include "src/lite_kernel.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include "src/runtime/kernel/arm/opclib/fp32/common_func.h" + +namespace mindspore::kernel { +class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution1x1CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution1x1CPUKernel(); + int Init() override; + int Run() override; + int ReSize() override; + + public: + int DoStrassen(int task_id); + int DoPostFunc(int task_id); + + private: + int InitConv1x1Param(); + int InitConv1x1BiasWeight(); + void InitConv1x1MatmulParam(); + void Pre1x1Trans(float *src_input, float *src_output); + + private: + StrassenMatMulParameter *matmul_param_ = nullptr; + bool pre_trans_input_ = false; + int thread_count_ = 0; + int thread_hw_count_ = 0; + int thread_hw_stride_ = 0; + int thread_oc4_count_ = 0; + int thread_oc_stride_ = 0; + float *bias_ptr_ = nullptr; + float *weight_ptr_ = nullptr; + float *tmp_ptr_ = nullptr; + float *c4_input_ = nullptr; + float *c4_output_ = nullptr; + float *input_ptr_ = nullptr; + float *output_ptr_ = nullptr; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc new file mode 100644 index 00000000000..3bf8682a37e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_3x3.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oc8 = UP_DIV(output_channel, C8NUM); + + size_t tmp_size = oc8 * C8NUM * iC4 * C4NUM * kernel_plane * sizeof(float); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + if (tmp_addr == nullptr) { + MS_LOG(ERROR) << "malloc tmp_addr failed."; + return; + } + memset(tmp_addr, 0, tmp_size); + + PackNHWCToNC4HW4Fp32(origin_weight, tmp_addr, output_channel, kernel_plane, input_channel); + Conv3x3Fp32FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane); + + free(tmp_addr); +} + +int Convolution3x3CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + auto output_channel = conv_param_->output_channel_; + int iC4 = UP_DIV(input_channel, C4NUM); + int oC4 = UP_DIV(output_channel, C4NUM); + int oC8 = UP_DIV(output_channel, C8NUM); + int k_plane = 16; + // init weight + size_t transformed_size = iC4 * C4NUM * oC8 * C8NUM * k_plane * sizeof(float); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed filter addr failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + auto weight_data = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + ProcessFilter(weight_data, transformed_filter_addr_, conv_param_); + + // init bias + size_t new_bias_size = oC4 * C4NUM * sizeof(float); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias data failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3CPUKernel::InitTmpBuffer() { + int iC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int k_plane = 16; + // todo + size_t tile_buffer_size = thread_count_ * TILE_NUM * k_plane * iC4 * C4NUM * sizeof(float); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile buffer failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * k_plane * oC4 * C4NUM * sizeof(float); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + size_t nhwc4_input_size = + iC4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + + size_t nc4hw4_out_size = + oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float); + nc4hw4_out_ = reinterpret_cast(malloc(nc4hw4_out_size)); + if (nc4hw4_out_ == nullptr) { + MS_LOG(ERROR) << "malloc nc4hw4_out_ failed."; + return RET_ERROR; + } + memset(nc4hw4_out_, 0, nc4hw4_out_size); + tmp_buffer_address_list_[0] = tile_buffer_; + tmp_buffer_address_list_[1] = block_unit_buffer_; + tmp_buffer_address_list_[2] = tmp_dst_buffer_; + tmp_buffer_address_list_[3] = nc4hw4_out_; + return RET_OK; +} + +void Convolution3x3CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +} + +int Convolution3x3CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3CPUKernel::ReSize() { + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + if (nc4hw4_out_ != nullptr) { + free(nc4hw4_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Conv3x3Fp32(reinterpret_cast(nhwc4_input_), transformed_filter_addr_, reinterpret_cast(bias_data_), + output_addr, tmp_buffer_address_list_, task_id, conv_param_); + return RET_OK; +} + +int Convolution3x3Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv3x3 = reinterpret_cast(cdata); + auto error_code = conv3x3->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h new file mode 100644 index 00000000000..9d9880ea582 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" + +namespace mindspore::kernel { +class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3CPUKernel() override { + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (nc4hw4_out_ != nullptr) { + free(nc4hw4_out_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + float *transformed_filter_addr_; + float *tile_buffer_; + float *block_unit_buffer_; + float *tmp_dst_buffer_; + float *nc4hw4_out_; + TmpBufferAddress tmp_buffer_address_list_[4]; +}; +void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc new file mode 100644 index 00000000000..6fcd243330d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int ConvolutionDepthwiseCPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + sliding = new SlidingWindowParam; + InitSlidingParam(sliding, conv_param_, C4NUM); + + // pack input function: convert_func_ + auto input_tensor = inputs_[kInputIndex]; + auto data_type = input_tensor->data_type(); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + if (input_format != execute_format) { + convert_func_ = LayoutTransform(data_type, input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return RET_ERROR; + } + } + + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + + // init threadNum; + conv_param_->thread_num_ = MSMIN(thread_count_, OC4); + ReSize(); + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::ReSize() { + // malloc pack input buffer + if (convert_func_ != nullptr) { + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + memset(packed_input_, 0, pack_input_size * sizeof(float)); + } + + // malloc tmp output buffer + if (conv_param_->output_channel_ % C4NUM != 0) { + need_align_ = true; + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + memset(packed_output_, 0, pack_output_size * sizeof(float)); + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::Execute(int task_id) { + ConvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding, task_id); + return RET_OK; +} + +int ConvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw = reinterpret_cast(cdata); + auto ret = conv_dw->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwiseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseCPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (convert_func_ != nullptr) { + convert_func_(input_addr, packed_input_, conv_param_->input_batch_, conv_param_->input_h_ * conv_param_->input_w_, + conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(float)); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h new file mode 100644 index 00000000000..88266b88dd5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionDepthwiseCPUKernel() override { + delete sliding; + free(packed_weight_); + if (convert_func_ != nullptr) { + free(packed_input_); + } + if (need_align_) { + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int Execute(int task_id); + + private: + SlidingWindowParam *sliding; + float *packed_weight_; + float *packed_input_; + float *packed_output_; + float *output_addr; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc new file mode 100644 index 00000000000..3a021d2b662 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -0,0 +1,338 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit, + ConvParameter *conv_param) { + // original weight format : ohwi + auto channel_in = conv_param->input_channel_; + auto channel_out = conv_param->output_channel_; + int input_unit_square = input_unit * input_unit; + + // generate matrix_G && matrix_GT + auto matrix_g = TransformMatrixGenerator(input_unit, kernel_unit); + auto matrix_gt = TransformMatrixGenerator(kernel_unit, input_unit); + ChooseMatrixG(matrix_g, matrix_gt); + auto matrix_g_data = reinterpret_cast(matrix_g->GetData()); + auto matrix_gt_data = reinterpret_cast(matrix_gt->GetData()); + + // trans_filter = G*g*GT (g represents weight_data) + // separate into two steps ===> tmp = G*g ===> out = tmp * GT + auto tmp_weight_data = reinterpret_cast(malloc(kernel_unit * kernel_unit * sizeof(float))); + auto tmp_data = reinterpret_cast(malloc(input_unit * kernel_unit * sizeof(float))); + auto trans_out_data = reinterpret_cast(malloc(input_unit * input_unit * sizeof(float))); + bool row = true; + auto trans_weight_data = reinterpret_cast(trans_weight->GetData()); + std::vector strides = trans_weight->GetStride(); + + int kernel_plane_stride = channel_in; + for (int i = 0; i < channel_out; i++) { + int oc8_block = i / C8NUM; + int oc8_res = i % C8NUM; + int input_oz_offset = i * kernel_unit * kernel_unit * channel_in; + int output_oz_offset = oc8_block * strides[1] * input_unit * input_unit + oc8_res; + for (int j = 0; j < channel_in; j++) { + int ic4_block = j / C4NUM; + int ic4_res = j % C4NUM; + int input_iz_offset = input_oz_offset + j; + int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; + for (int k = 0; k < kernel_unit * kernel_unit; k++) { + int input_xy_offset = input_iz_offset + k * kernel_plane_stride; + tmp_weight_data[k] = *(weight_data + input_xy_offset); + } + // now we only support row-major matrix-multiply + // tmp = G * g + MatrixMultiply(matrix_g_data, tmp_weight_data, tmp_data, input_unit, kernel_unit, kernel_unit, row); + // out = tmp * GT + MatrixMultiply(tmp_data, matrix_gt_data, trans_out_data, input_unit, kernel_unit, input_unit, row); + + for (int z = 0; z < input_unit_square; z++) { + int output_xy_offset = output_iz_offset + z * strides[1]; + *(trans_weight_data + output_xy_offset) = trans_out_data[z]; + } + } + } + free(tmp_weight_data); + free(tmp_data); + free(trans_out_data); + delete matrix_g; + delete matrix_gt; +} + +int ConvolutionWinogradCPUKernel::InitWeightBias() { + int output_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + + // init weight + auto ret = MallocFilterMatrix(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc filter matrix failed."; + return RET_ERROR; + } + auto weight_tensor = inputs_.at(kWeightIndex); + auto weight_data = reinterpret_cast(weight_tensor->Data()); + WinogradFilterTransform(weight_data, trans_weight_, kernel_unit_, input_unit_, conv_param_); + + // init bias + size_t new_bias_size = oc4 * C4NUM * sizeof(float); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::MallocFilterMatrix() { + int channel_in = conv_param_->input_channel_; + int channel_out = conv_param_->output_channel_; + int ic4 = UP_DIV(channel_in, BLOCK); + int oc8 = UP_DIV(channel_out, C8NUM); + + // set data + auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * oc8 * C4NUM * C8NUM * sizeof(float); + auto matrix_buffer = malloc(trans_matrix_data_size); + if (matrix_buffer == nullptr) { + MS_LOG(ERROR) << "malloc matrix_buffer failed."; + return RET_ERROR; + } + memset(matrix_buffer, 0, trans_matrix_data_size); + trans_weight_ = new Matrix(); + trans_weight_->SetData(matrix_buffer); + trans_weight_->SetNDim(5); + + std::vector shapes; + std::vector strides; + // set shape + shapes.push_back(input_unit_ * input_unit_); + shapes.push_back(oc8); + shapes.push_back(ic4); + shapes.push_back(C4NUM); + shapes.push_back(C8NUM); + // set stride + for (int i = 0; i < 4; i++) { + int stride = 1; + for (int j = i + 1; j < 5; j++) { + stride *= shapes[j]; + } + strides.push_back(stride); + } + trans_weight_->SetShape(shapes); + trans_weight_->SetStride(strides); + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::InitTmpBuffer() { + int channel_in = conv_param_->input_channel_; + int channel_out = conv_param_->output_channel_; + int output_h = conv_param_->output_h_; + int output_w = conv_param_->output_w_; + int ic4 = UP_DIV(channel_in, C4NUM); + int oc4 = UP_DIV(channel_out, C4NUM); + + size_t tile_buffer_size = thread_count_ * TILE_NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); + trans_input_ = reinterpret_cast(malloc(tile_buffer_size)); + if (trans_input_ == nullptr) { + MS_LOG(ERROR) << "malloc trans_input_ failed."; + return RET_ERROR; + } + memset(trans_input_, 0, tile_buffer_size); + + gemm_out_ = reinterpret_cast( + malloc(thread_count_ * TILE_NUM * input_unit_ * input_unit_ * oc4 * C4NUM * sizeof(float))); + if (gemm_out_ == nullptr) { + MS_LOG(ERROR) << "malloc gemm_out_ failed."; + return RET_ERROR; + } + + int out_w_block = UP_DIV(output_w, output_unit_); + int out_h_block = UP_DIV(output_h, output_unit_); + tmp_out_data_ = reinterpret_cast( + malloc(out_w_block * out_h_block * output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); + if (tmp_out_data_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_data_ failed."; + return RET_ERROR; + } + + tmp_data_ = reinterpret_cast(malloc(C4NUM * input_unit_ * input_unit_ * sizeof(float))); + if (tmp_data_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_data_ failed."; + return RET_ERROR; + } + memset(tmp_data_, 0, C4NUM * input_unit_ * input_unit_ * sizeof(float)); + + tmp_buffer_address_list_[0] = trans_input_; + tmp_buffer_address_list_[1] = gemm_out_; + tmp_buffer_address_list_[2] = tmp_out_data_; + tmp_buffer_address_list_[3] = tmp_data_; + + size_t nhwc4_input_size = + ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float); + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::ConfigInputOutput() { + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return RET_ERROR; + } + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + + // choose input transformer function (4x4 unit or 8x8 unit) + input_trans_func_ = GetInputTransFunc(input_unit_); + if (input_trans_func_ == nullptr) { + MS_LOG(ERROR) << "Get input_trans_func failed."; + return RET_ERROR; + } + output_trans_func_ = GetOutputTransFunc(input_unit_, output_unit_); + if (output_trans_func_ == nullptr) { + MS_LOG(ERROR) << "Get output_trans_func_ failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + kernel_unit_ = conv_param_->kernel_h_; + input_unit_ = output_unit_ + kernel_unit_ - 1; + conv_param_->input_unit_ = input_unit_; + conv_param_->output_unit_ = output_unit_; + + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // malloc tmp buffer + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ret = ConfigInputOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConfigInputOutput failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::ReSize() { + if (tmp_data_ != nullptr) { + free(tmp_data_); + } + if (trans_input_ != nullptr) { + free(trans_input_); + } + if (gemm_out_ != nullptr) { + free(gemm_out_); + } + if (tmp_out_data_ != nullptr) { + free(tmp_out_data_); + } + if (nhwc4_input_ != nullptr) { + free(nhwc4_input_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + kernel_unit_ = conv_param_->kernel_h_; + input_unit_ = output_unit_ + kernel_unit_ - 1; + conv_param_->input_unit_ = input_unit_; + conv_param_->output_unit_ = output_unit_; + + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + ret = ConfigInputOutput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConfigInputOutput failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + ConvWinogardFp32(reinterpret_cast(nhwc4_input_), reinterpret_cast(trans_weight_->GetData()), + reinterpret_cast(bias_data_), output_addr, tmp_buffer_address_list_, task_id, + conv_param_, input_trans_func_, output_trans_func_); + return RET_OK; +} + +int ConvolutionWinogradImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ConvolutionWinograd Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionWinogradCPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionWinogradImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h new file mode 100644 index 00000000000..6518db604db --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/base/matrix.h" + +namespace mindspore::kernel { +class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx, int output_unit) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx), output_unit_(output_unit) {} + ~ConvolutionWinogradCPUKernel() override { + if (tmp_data_ != nullptr) { + free(tmp_data_); + } + if (trans_input_ != nullptr) { + free(trans_input_); + } + if (gemm_out_ != nullptr) { + free(gemm_out_); + } + if (tmp_out_data_ != nullptr) { + free(tmp_out_data_); + } + delete trans_weight_; + }; + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int MallocFilterMatrix(); + int InitTmpBuffer(); + int ConfigInputOutput(); + + private: + int kernel_unit_; + int input_unit_; + int output_unit_; + float *tmp_data_; + float *trans_input_; + float *gemm_out_; + float *tmp_out_data_; + Matrix *trans_weight_; + InputTransformUnitFunc input_trans_func_; + OutputTransformUnitFunc output_trans_func_; + TmpBufferAddress tmp_buffer_address_list_[5]; +}; +void WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int kernel_unit, int input_unit, + ConvParameter *conv_param); +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc new file mode 100644 index 00000000000..f974e6df5d0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/crop.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/crop.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Crop; + +namespace mindspore::kernel { +int CropCPUKernel::Init() { + schema::Format input0_format = inputs_[0]->GetFormat(); + if (input0_format != schema::Format_NC4HW4) { + outputs_[0]->SetFormat(input0_format); + return RET_OK; + } + convert_function_ = LayoutTransform(inputs_[0]->data_type(), inputs_[0]->GetFormat(), schema::Format_NHWC); + if (convert_function_ == nullptr) { + MS_LOG(ERROR) << "Can not convert format " << inputs_[0]->GetFormat() << " to " << schema::Format_NHWC; + return RET_ERROR; + } + auto packed_input_size = inputs_[0]->Channel() * inputs_[0]->Batch() * inputs_[0]->Height() * inputs_[0]->Width(); + packed_input_ = reinterpret_cast(malloc(packed_input_size * sizeof(float))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc memory fail!"; + return RET_ERROR; + } + memset(packed_input_, 0, packed_input_size * sizeof(float)); + return RET_OK; +} + +int CropCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + float *input_data = reinterpret_cast(input->Data()); + if (convert_function_ != nullptr) { + convert_function_(input_data, packed_input_, inputs_[0]->Batch(), inputs_[0]->Height() * inputs_[0]->Width(), + inputs_[0]->Channel()); + } else { + packed_input_ = input_data; + } + float *output_data = reinterpret_cast(output->Data()); + Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), + reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuCropFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new CropCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Crop, CpuCropFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h new file mode 100644 index 00000000000..b45a97c0810 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/base/layout_transform.h" + +namespace mindspore::kernel { +class CropCPUKernel : public LiteKernel { + public: + CropCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), packed_input_(nullptr), convert_function_(nullptr) {} + ~CropCPUKernel() { + if (packed_input_ != nullptr) { + free(packed_input_); + packed_input_ = nullptr; + } + } + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + float *packed_input_; + LayoutConvertor convert_function_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CROP_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc new file mode 100644 index 00000000000..c7e4323acea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc @@ -0,0 +1,227 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/deconvolution.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { +DeConvolutionCPUKernel::~DeConvolutionCPUKernel() { + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + if (tmp_output_ != nullptr) { + free(tmp_output_); + tmp_output_ = nullptr; + } + + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (c4_input_ != nullptr) { + free(c4_input_); + c4_input_ = nullptr; + } + if (c4_output_ != nullptr) { + free(c4_output_); + c4_output_ = nullptr; + } + return; +} + +int DeConvolutionCPUKernel::ReSize() { return 0; } + +int DeConvolutionCPUKernel::InitWeightBias() { + if (inputs_.size() == 3) { + bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(float)); + memcpy(bias_data_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); + } else { + bias_data_ = nullptr; + } + + size_t weight_pack_size = conv_param_->kernel_w_ * conv_param_->kernel_h_ * + UP_ROUND(conv_param_->output_channel_, C4NUM) * + UP_ROUND(conv_param_->input_channel_, C4NUM) * sizeof(float); + weight_ptr_ = reinterpret_cast(malloc(weight_pack_size)); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "deconv malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, weight_pack_size); + PackDeConvWeightFp32(reinterpret_cast(inputs_[1]->Data()), weight_ptr_, conv_param_->input_channel_, + conv_param_->output_channel_, conv_param_->kernel_w_ * conv_param_->kernel_h_); + return RET_OK; +} + +int DeConvolutionCPUKernel::InitParam() { + matmul_param_ = new StrassenMatMulParameter(); + matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; + matmul_param_->deep_ = UP_DIV(conv_param_->input_channel_, C4NUM); + matmul_param_->col_ = UP_DIV(conv_param_->output_channel_, 4) * conv_param_->kernel_w_ * conv_param_->kernel_h_; + matmul_param_->a_stride_ = matmul_param_->row_ * C4NUM; + matmul_param_->b_stride_ = matmul_param_->deep_ * C4NUM * C4NUM; + matmul_param_->c_stride_ = matmul_param_->row_ * C4NUM; + + thread_hw_count_ = MSMIN(opParameter->thread_num_, matmul_param_->row_); + thread_hw_stride_ = UP_DIV(matmul_param_->row_, thread_hw_count_); + + thread_co4_count_ = MSMIN(opParameter->thread_num_, UP_DIV(conv_param_->output_channel_, C4NUM)); + thread_co_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C4NUM), thread_co4_count_) * C4NUM; + + tmp_buffer_ = + reinterpret_cast(malloc(matmul_param_->a_stride_ * matmul_param_->deep_ * C4NUM * sizeof(float))); + if (tmp_buffer_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!"; + return RET_ERROR; + } + + tmp_output_ = reinterpret_cast(malloc(matmul_param_->row_ * matmul_param_->col_ * C4NUM * sizeof(float))); + if (tmp_output_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc tmp_output_ error!"; + return RET_ERROR; + } + + c4_input_ = + reinterpret_cast(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * sizeof(float))); + if (c4_input_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc c4_input_ error!"; + return RET_NULL_PTR; + } + + c4_output_ = + reinterpret_cast(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * sizeof(float))); + if (c4_output_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 Malloc c4_output_ error!"; + return RET_NULL_PTR; + } + return RET_OK; +} + +int DeConvFp32Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoDeconv(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp32Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +int DeConvFp32PostRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoPostFunc(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp32PostRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +int DeConvolutionCPUKernel::DoDeconv(int task_id) { + matmul_param_->row_ = MSMIN(thread_hw_stride_, matmul_param_->row_ - task_id * thread_hw_stride_); + if (matmul_param_->row_ <= 0) { + return RET_OK; + } + + int error_code = DeConvFp32(c4_input_ + task_id * thread_hw_stride_ * C4NUM, weight_ptr_, + tmp_output_ + task_id * thread_hw_stride_ * C4NUM, + tmp_buffer_ + task_id * thread_hw_stride_ * matmul_param_->deep_ * C4NUM, *matmul_param_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvFp32 error! error code: " << error_code; + return error_code; + } + + matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; + return RET_OK; +} + +int DeConvolutionCPUKernel::DoPostFunc(int task_id) { + int input_plane = conv_param_->input_h_ * conv_param_->input_w_; + int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; + int output_plane = conv_param_->output_h_ * conv_param_->output_w_; + + int cur_oc = MSMIN(thread_co_stride_, conv_param_->output_channel_ - task_id * thread_co_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + DeConvPostFp32(tmp_output_ + thread_co_stride_ * task_id * input_plane * kernel_plane, + c4_output_ + thread_co_stride_ * task_id * output_plane, output_ptr_ + thread_co_stride_ * task_id, + reinterpret_cast(bias_data_) + thread_co_stride_ * task_id, cur_oc, input_plane, kernel_plane, + output_plane, conv_param_); + return RET_OK; +} + +int DeConvolutionCPUKernel::Init() { + int error_code = ConvolutionBaseCPUKernel::Init(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Conv base init error!"; + return error_code; + } + + error_code = InitParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv InitParam error!"; + return error_code; + } + + error_code = InitWeightBias(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv InitWeightBias error!"; + return error_code; + } + return RET_OK; +} + +int DeConvolutionCPUKernel::Run() { + float *src_in = reinterpret_cast(inputs_[0]->Data()); + float *src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + input_ptr_ = src_in + batch_index * conv_param_->input_w_ * conv_param_->input_h_ * conv_param_->input_channel_; + output_ptr_ = + src_out + batch_index * conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_channel_; + + PackNHWCToNC4HW4Fp32(input_ptr_, c4_input_, 1, conv_param_->input_h_ * conv_param_->input_w_, + conv_param_->input_channel_); + + int error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_hw_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv fp32 run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + + error_code = LiteBackendParallelLaunch(DeConvFp32PostRun, this, thread_co4_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv fp32 postrun error! error_code[" << error_code << "]"; + return RET_ERROR; + } + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h new file mode 100644 index 00000000000..bd592961d9b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ + +#include +#include "src/lite_kernel.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/deconv.h" + +namespace mindspore::kernel { +class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel { + public: + DeConvolutionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeConvolutionCPUKernel() override; + int Init() override; + int Run() override; + int ReSize() override; + + public: + int DoDeconv(int task_id); + int DoPostFunc(int task_id); + + private: + int InitParam(); + int InitWeightBias(); + + private: + StrassenMatMulParameter *matmul_param_; + int thread_hw_count_; + int thread_hw_stride_; + int thread_co4_count_; + int thread_co_stride_; + float *weight_ptr_; + float *tmp_buffer_; + float *tmp_output_; + float *c4_input_; + float *c4_output_; + float *input_ptr_; + float *output_ptr_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc new file mode 100644 index 00000000000..e6963ee403c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/deconvolution_depthwise.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + +namespace mindspore::kernel { +int DeconvolutionDepthwiseCPUKernel::InitSlideParam() { + conv_param_->input_batch_ = outputs_.front()->shape().at(kNHWC_N); + conv_param_->input_h_ = outputs_.front()->shape().at(kNHWC_H); + conv_param_->input_w_ = outputs_.front()->shape().at(kNHWC_W); + conv_param_->input_channel_ = outputs_.front()->shape().at(kNHWC_C); + conv_param_->output_batch_ = inputs_.front()->shape().at(kNHWC_N); + conv_param_->output_h_ = inputs_.front()->shape().at(kNHWC_H); + conv_param_->output_w_ = inputs_.front()->shape().at(kNHWC_W); + conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C); + + // init sliding window param + sliding = new SlidingWindowParam; + InitSlidingParam(sliding, conv_param_, C4NUM); + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::Init() { + InitSlideParam(); + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // pack input function: convert_func_ + auto input_tensor = inputs_[kInputIndex]; + auto data_type = input_tensor->data_type(); + auto input_format = input_tensor->GetFormat(); + schema::Format execute_format = schema::Format_NHWC4; + if (input_format != execute_format) { + convert_func_ = LayoutTransform(data_type, input_format, execute_format); + if (convert_func_ == nullptr) { + MS_LOG(ERROR) << "layout convert func is nullptr."; + return RET_ERROR; + } + } + + // init weight: o, h, w, i; o == group, i == 1 + auto weight_tensor = inputs_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, conv_param_->kernel_h_ * conv_param_->kernel_w_, + conv_param_->output_channel_); + + // init bias + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(float))); + memset(bias_data_, 0, C4NUM * OC4 * sizeof(float)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(float)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + + // init threadNum; + conv_param_->thread_num_ = MSMIN(conv_param_->thread_num_, OC4); + ReSize(); + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::ReSize() { + // malloc pack input buffer + if (convert_func_ != nullptr) { + int IC4 = UP_DIV(conv_param_->input_channel_, C4NUM); + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * IC4; + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(float))); + memset(packed_input_, 0, pack_input_size * sizeof(float)); + } + + // malloc tmp output buffer + if (conv_param_->output_channel_ % C4NUM != 0) { + need_pack_ = true; + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * OC4; + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(float))); + memset(packed_output_, 0, pack_output_size * sizeof(float)); + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::DoExcute(int task_id) { + DeconvDwC4Fp32(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding, task_id); + return RET_OK; +} + +int DeconvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw = reinterpret_cast(cdata); + auto ret = conv_dw->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvolutionDepthwiseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseCPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + + // pack input: to nhwc4 + if (convert_func_ != nullptr) { + convert_func_(input_addr, packed_input_, conv_param_->input_batch_, conv_param_->input_h_ * conv_param_->input_w_, + conv_param_->input_channel_); + } else { + packed_input_ = input_addr; + } + + output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(float)); + if (!need_pack_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(DeconvDwRun, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_pack_) { + PackNHWC4ToNHWCFp32(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h new file mode 100644 index 00000000000..ec612584fc2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class DeconvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { + public: + DeconvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeconvolutionDepthwiseCPUKernel() override { + delete sliding; + free(packed_weight_); + free(packed_input_); + free(packed_output_); + }; + + int Init() override; + int InitSlideParam(); + int ReSize() override; + int Run() override; + + int DoExcute(int task_id); + + private: + SlidingWindowParam *sliding; + float *packed_weight_; + float *packed_input_; + float *packed_output_; + float *output_addr; + bool need_pack_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc new file mode 100644 index 00000000000..a96c598b0b6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/depth_to_space.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_FORMAT_ERR; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_PARAM_INVALID; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DepthToSpace; + +namespace mindspore::kernel { + +int DepthToSpaceCPUKernel::Init() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + if (param->block_size_ <= 0) { + MS_LOG(ERROR) << "Input block_size should > 0!"; + return RET_PARAM_INVALID; + } + return RET_OK; +} + +int DepthToSpaceCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + const float *input_data = reinterpret_cast(input->Data()); + float *output_data = reinterpret_cast(output->Data()); + auto in_shape = input->shape(); + auto out_shape = output->shape(); + DepthToSpaceParameter *param = reinterpret_cast(opParameter); + if (input->GetFormat() == schema::Format_NHWC) { + DepthToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape.data(), in_shape.size(), + param->block_size_); + return RET_OK; + } else { + MS_LOG(ERROR) << "Only support NHWC now!"; + return RET_ERROR; + } +} +kernel::LiteKernel *CpuDepthToSpaceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_DepthToSpace, CpuDepthToSpaceFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h new file mode 100644 index 00000000000..88638c148d5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class DepthToSpaceCPUKernel : public LiteKernel { + public: + DepthToSpaceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~DepthToSpaceCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc new file mode 100644 index 00000000000..177e8d8e147 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/expandDims.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ExpandDims; + +namespace mindspore::kernel { +int ExpandDimsCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int ExpandDimsCPUKernel::ReSize() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int ExpandDimsCPUKernel::DoExpandDims(int task_id) { + size_t size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size == 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = ExpandDims(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ExpandDimsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoExpandDims(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ExpandDimsCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.at(0)->Data()); + out_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + int ret = LiteBackendParallelLaunch(ExpandDimsRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ExpandDimsRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_ExpandDims); + auto *kernel = new (std::nothrow) ExpandDimsCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_ExpandDims, CpuExpandsDimsFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h new file mode 100644 index 00000000000..fc47ac2bc7e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EXPANDDIMS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_EXPANDDIMS_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/expandDims.h" +#include "schema/model_generated.h" + +#include "include/context.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ExpandDimsCPUKernel : public LiteKernel { + public: + ExpandDimsCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {} + ~ExpandDimsCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExpandDims(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + size_t data_size_; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_EXPANDDIMS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc new file mode 100644 index 00000000000..05b69497cf5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fill.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Fill; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int FillCPUKernel::Init() { + data_size_ = outputs_.front()->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int FillCPUKernel::ReSize() { return RET_OK; } + +int FillCPUKernel::DoFill(int task_id) { + int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = Fill(out_ptr_ + offset, size, src_data_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int FillRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoFill(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int FillCPUKernel::Run() { + auto fillData = inputs_.at(inputs_.size() - 1); + auto output = outputs_.front(); + auto fill_data = reinterpret_cast(fillData->Data()); + src_data_ = fill_data[0]; + out_ptr_ = reinterpret_cast(output->Data()); + int ret = LiteBackendParallelLaunch(FillRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Fill. "; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Fill); + auto *kernel = new (std::nothrow) FillCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Fill, CpuFillFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h new file mode 100644 index 00000000000..775203aa932 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/fp32/fill.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FillCPUKernel : public LiteKernel { + public: + FillCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {} + ~FillCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoFill(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + float src_data_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc new file mode 100644 index 00000000000..6813ea8c40d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/flatten.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/flatten.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Flatten; + +namespace mindspore::kernel { +int FlattenCPUKernel::Init() { + auto output_shape = outputs_[0]->shape(); + flatten_param_->size = sizeof(float); + for (int i = 0; i < output_shape.size(); i++) { + flatten_param_->size *= output_shape[i]; + } + return RET_OK; +} + +int FlattenCPUKernel::ReSize() { return RET_OK; } + +int FlattenCPUKernel::Run() { + auto input = reinterpret_cast(inputs_[0]->Data()); + auto output = reinterpret_cast(outputs_[0]->Data()); + Flatten(input, output, flatten_param_); + return RET_OK; +} + +kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Flatten. "; + return nullptr; + } + MS_ASSERT(desc.type == schema::PrimitiveType_Flatten); + auto *kernel = new (std::nothrow) FlattenCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Flatten, CpuFlattenFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h new file mode 100644 index 00000000000..3b2db2523d9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/flatten.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FlattenCPUKernel : public LiteKernel { + public: + FlattenCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + flatten_param_ = reinterpret_cast(parameter); + } + ~FlattenCPUKernel() override { delete flatten_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + FlattenParameter *flatten_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc new file mode 100644 index 00000000000..80dce75f855 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fullconnection.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +FullconnectionCPUKernel::~FullconnectionCPUKernel() { + if (a_c8_ptr_ != nullptr) { + free(a_c8_ptr_); + a_c8_ptr_ = nullptr; + } + if (b_r8_ptr_ != nullptr) { + free(b_r8_ptr_); + b_r8_ptr_ = nullptr; + } + if (c_r8x8_ptr_ != nullptr) { + free(c_r8x8_ptr_); + c_r8x8_ptr_ = nullptr; + } + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; + } +} + +int FullconnectionCPUKernel::ReSize() { return RET_OK; } + +int FullconnectionCPUKernel::Init() { + fc_param_->row_ = (inputs_[0]->shape())[0]; + fc_param_->col_ = (inputs_[1]->shape())[0]; + fc_param_->deep_ = (inputs_[1]->shape())[1]; + + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8); + fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8); + + thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); + thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); + + bias_ptr_ = reinterpret_cast(malloc(fc_param_->col_8_ * sizeof(float))); + memset(bias_ptr_, 0, fc_param_->col_8_ * sizeof(float)); + if (inputs_.size() == 3) { + memcpy(bias_ptr_, inputs_[2]->Data(), fc_param_->col_ * sizeof(float)); + } + + a_c8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(float))); + if (a_c8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(float)); + + b_r8_ptr_ = reinterpret_cast(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float))); + if (b_r8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)); + RowMajor2Col8Major(reinterpret_cast(inputs_[1]->Data()), b_r8_ptr_, fc_param_->col_, fc_param_->deep_); + + c_r8x8_ptr_ = reinterpret_cast(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float))); + if (c_r8x8_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)); + return RET_OK; +} + +int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto fc = reinterpret_cast(cdata); + auto error_code = fc->DoMatmul(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "FcFp32MatmulRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int FullconnectionCPUKernel::DoMatmul(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_, + c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_, + bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->maxf_, fc_param_->minf_, fc_param_->deep_, + fc_param_->row_8_, cur_oc * 8); + return RET_OK; +} + +int FullconnectionCPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + + RowMajor2Col8Major(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + + LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_); + + Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_); + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h new file mode 100644 index 00000000000..c654945f497 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ + +#include +#include "include/errorcode.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/fp32/matmul.h" +#include "src/runtime/kernel/arm/base/fullconnection_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel { + public: + FullconnectionCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~FullconnectionCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + + public: + int DoMatmul(int task_id); + + private: + float *a_c8_ptr_; + float *b_r8_ptr_; + float *c_r8x8_ptr_; + float *bias_ptr_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc new file mode 100644 index 00000000000..c54bf2ce85c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/fused_batchnorm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_FusedBatchNorm; + +namespace mindspore::kernel { +int FusedBatchnormCPUKernel::Init() { + input_shape_ = reinterpret_cast(malloc(sizeof(int) * inputs_[0]->shape().size())); + memcpy(input_shape_, inputs_[0]->shape().data(), inputs_[0]->shape().size() * sizeof(int)); + return RET_OK; +} + +int FusedBatchnormCPUKernel::ReSize() { return RET_OK; } + +int FusedBatchnormCPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto scale_addr = reinterpret_cast(inputs_.at(1)->Data()); + auto offest_addr = reinterpret_cast(inputs_.at(2)->Data()); + auto mean_addr = reinterpret_cast(inputs_.at(3)->Data()); + auto variance_addr = reinterpret_cast(inputs_.at(4)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + + FusedBatchNorm(input_addr, scale_addr, offest_addr, mean_addr, variance_addr, input_shape_, + fused_batchnorm_param_->epsilon_, output_addr); + return RET_OK; +} + +kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_FusedBatchNorm); + auto *kernel = new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_FusedBatchNorm, CpuFusedBatchnormKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h new file mode 100644 index 00000000000..0ab0e2e8a97 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fused_batchnorm.h" + +namespace mindspore::kernel { +class FusedBatchnormCPUKernel : public LiteKernel { + public: + FusedBatchnormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + fused_batchnorm_param_ = reinterpret_cast(parameter); + } + ~FusedBatchnormCPUKernel() override { delete fused_batchnorm_param_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + int *input_shape_{}; + FusedBatchNormParameter *fused_batchnorm_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc new file mode 100644 index 00000000000..dd0d899f8f1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32/gather.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Gather; + +namespace mindspore::kernel { + +int GatherCPUKernel::Init() { + axis_ = (reinterpret_cast(opParameter))->axis_; + batchDims_ = (reinterpret_cast(opParameter))->batchDims_; + return RET_OK; +} + +int GatherCPUKernel::ReSize() { return RET_OK; } + +int GatherCPUKernel::DoGather(int task_id) { + auto input_tensor = inputs_.at(0); + auto indices_tensor = inputs_.at(1); + auto out_tensor = outputs_.at(0); + + auto input_ptr = reinterpret_cast(input_tensor->Data()); + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + auto output_ptr = reinterpret_cast(out_tensor->Data()); + + auto in_shape = input_tensor->shape(); + int in_rank = in_shape.size(); + int indices_element_size = indices_tensor->ElementsNum(); + + const int limit = in_shape[axis_]; + for (size_t i = 0; i < indices_element_size; ++i) { + if (indices_ptr[i] >= limit) { + MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]"; + return RET_ERROR; + } + } + + int outer_size = 1; + for (int i = 0; i < axis_; ++i) { + outer_size *= in_shape[i]; + } + + int inner_size = 1; + for (int i = axis_ + 1; i < in_rank; ++i) { + inner_size *= in_shape[i]; + } + + int stride = UP_DIV(outer_size, thread_count_); + int count = MSMIN(stride, outer_size - stride * task_id); + + input_ptr += stride * task_id * limit; + output_ptr += stride * task_id * indices_element_size; + + auto error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); + if (error_code != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto gather_kernel = reinterpret_cast(cdata); + auto error_code = gather_kernel->DoGather(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int GatherCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(GatherRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Gather); + + auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Gather, CpuGatherFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h new file mode 100644 index 00000000000..f2f8f25fe65 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ + +#include +#include "src/runtime/kernel/arm/opclib/fp32/gather.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class GatherCPUKernel : public LiteKernel { + public: + GatherCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->threadNum) {} + ~GatherCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoGather(int task_id); + + private: + int thread_count_; + int batchDims_; + int axis_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc new file mode 100644 index 00000000000..a9e02bab125 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/gatherNd.h" +#include +#include +#include "schema/model_generated.h" +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_GatherNd; + +namespace mindspore::kernel { + +GatherNdCPUKernel::~GatherNdCPUKernel() { + if (in_offset_ != nullptr) { + free(in_offset_); + in_offset_ = nullptr; + } +} + +int GatherNdCPUKernel::Init() { + auto indices_tensor = inputs_.at(1); + auto indices_shape = indices_tensor->shape(); + int indices_rank = indices_shape.size(); + count_ = 1; + for (int i = 0; i < indices_rank - 1; ++i) { + count_ *= indices_shape[i]; + } + + in_offset_ = reinterpret_cast(malloc(count_ * sizeof(int))); + if (in_offset_ == nullptr) { + MS_LOG(ERROR) << "GatherNd Malloc in_offset_ error!"; + return RET_ERROR; + } + (void)memset(in_offset_, 0, count_ * sizeof(int)); + + thread_sz_count_ = MSMIN(thread_count_, count_); + thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + int ret = ReSize(); + return ret; +} + +int GatherNdCPUKernel::ReSize() { + auto in_shape = inputs_.front()->shape(); + int in_rank = in_shape.size(); + auto indices_tensor = inputs_.at(1); + auto indices_shape = indices_tensor->shape(); + int indices_rank = indices_shape.size(); + int idx_lastshape = indices_shape[indices_rank - 1]; + auto indices_ptr = reinterpret_cast(indices_tensor->Data()); + area_ = 1; + for (int i = idx_lastshape; i < in_rank; ++i) { + area_ *= in_shape[i]; + } + std::vector in_stride(in_rank); + in_stride[in_rank - 1] = 1; + for (int i = in_rank - 2; i >= 0; --i) { + in_stride[i] = in_shape[i + 1] * in_stride[i + 1]; + } + + int idx_stride = idx_lastshape; + for (int j = 0; j < count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + + return RET_OK; +} + +int GatherNdCPUKernel::DoGatherNd(int task_id) { + int count = MSMIN(thread_sz_stride_, count_ - task_id * thread_sz_stride_); + if (count <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + auto ret = GatherNd(in_ptr_, out_ptr_ + offset * area_, in_offset_ + offset, area_, count); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int GatherNdRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoGatherNd(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GatherNdRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int GatherNdCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.front()->Data()); + out_ptr_ = reinterpret_cast(outputs_.front()->Data()); + int ret = LiteBackendParallelLaunch(GatherNdRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_GatherNd); + + auto *kernel = new (std::nothrow) GatherNdCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h new file mode 100644 index 00000000000..139da4ddafe --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ + +#include +#include "src/runtime/kernel/arm/opclib/fp32/gatherNd.h" +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class GatherNdCPUKernel : public LiteKernel { + public: + GatherNdCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {} + ~GatherNdCPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int DoGatherNd(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int count_; + int area_; + int *in_offset_ = nullptr; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc new file mode 100644 index 00000000000..575dad68978 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/local_response_norm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_LocalResponseNormalization; + +namespace mindspore::kernel { + +int LocalResponseNormCPUKernel::Init() { + depth_radius_ = (reinterpret_cast(opParameter))->depth_radius_; + bias_ = (reinterpret_cast(opParameter))->bias_; + alpha_ = (reinterpret_cast(opParameter))->alpha_; + beta_ = (reinterpret_cast(opParameter))->beta_; + return RET_OK; +} + +int LocalResponseNormCPUKernel::ReSize() { return RET_OK; } + +int LocalResponseNormCPUKernel::DoLocalResponseNorm(int task_id) { + auto input_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + auto input_ptr = reinterpret_cast(input_tensor->Data()); + auto output_ptr = reinterpret_cast(out_tensor->Data()); + + auto in_shape = input_tensor->shape(); + MS_ASSERT(in_shape.size() == 4); + + int batch = in_shape[0]; + int height = in_shape[1]; + int width = in_shape[2]; + int channel = in_shape[3]; + + int outer_size = batch * width * height; + int stride = UP_DIV(outer_size, thread_count_); + int count = MSMIN(stride, outer_size - stride * task_id); + + input_ptr += stride * task_id * channel; + output_ptr += stride * task_id * channel; + + auto error_code = LocalResponseNorm(input_ptr, count, channel, output_ptr, depth_radius_, bias_, alpha_, beta_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DoLocalResponseNorm error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int LocalResponseNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto lrn = reinterpret_cast(cdata); + auto error_code = lrn->DoLocalResponseNorm(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LocalResponseNormRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int LocalResponseNormCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(LocalResponseNormRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LocalResponseNorm function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_LocalResponseNormalization); + + auto *kernel = new (std::nothrow) LocalResponseNormCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_LocalResponseNormalization, CpuLocalResponseNormFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h new file mode 100644 index 00000000000..ac5e89f3cbd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ + +#include +#include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class LocalResponseNormCPUKernel : public LiteKernel { + public: + LocalResponseNormCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->threadNum) {} + ~LocalResponseNormCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoLocalResponseNorm(int task_id); + + private: + int thread_count_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc new file mode 100644 index 00000000000..b1ec51233dc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/matmul.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_MatMul; + +namespace mindspore::kernel { + +int MatmulCPUKernel::ReSize() { return RET_OK; } + +int MatmulCPUKernel::Run() { return RET_OK; } + +int MatmulCPUKernel::Init() { return RET_OK; } + +kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_MatMul); + auto *kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_MatMul, CpuMatmulFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h new file mode 100644 index 00000000000..3dfc4521eb4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/matmul.h" + +namespace mindspore::kernel { +class MatmulCPUKernel : public LiteKernel { + public: + explicit MatmulCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + matmul_param_ = reinterpret_cast(parameter); + } + ~MatmulCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + MatMulParameter *matmul_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc new file mode 100644 index 00000000000..38c10441ae3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Nchw2Nhwc; + +namespace mindspore::kernel { +int Nchw2NhwcCPUKernel::Init() { return RET_OK; } + +int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; } + +int Nchw2NhwcCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + + PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + return RET_OK; +} + +kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc); + auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h new file mode 100644 index 00000000000..1a5fe9c0648 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ + +#include +#include "src/lite_kernel.h" + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/pack.h" + +namespace mindspore::kernel { +class Nchw2NhwcCPUKernel : public LiteKernel { + public: + Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Nchw2NhwcCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc new file mode 100644 index 00000000000..b7441153409 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/nhwc2nchw.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Nhwc2Nchw; + +namespace mindspore::kernel { +int Nhwc2NchwCPUKernel::Init() { return RET_OK; } + +int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; } + +int Nhwc2NchwCPUKernel::Run() { + auto input = inputs_[0]; + auto output = outputs_[0]; + + PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + return RET_OK; +} + +kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw); + auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h new file mode 100644 index 00000000000..b2de852b9ae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ + +#include +#include "src/lite_kernel.h" + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/pack.h" + +namespace mindspore::kernel { +class Nhwc2NchwCPUKernel : public LiteKernel { + public: + Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Nhwc2NchwCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc new file mode 100644 index 00000000000..3b011dfbf48 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/one_hot.h" +#include "src/runtime/kernel/arm/opclib/fp32/one_hot.h" +#include "schema/model_generated.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_OneHot; + +namespace mindspore::kernel { +namespace { +constexpr size_t kInputNum = 4; +constexpr size_t kOutputNum = 1; +} // namespace + +int OneHotCPUKernel::Init() { + // indices depth on_value off_value + if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << inputs_.size() + << ", output size should be" << kOutputNum << ", got " << outputs_.size(); + return RET_ERROR; + } + + auto indices = inputs_.at(0); + if (indices == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr"; + return RET_NULL_PTR; + } + auto indices_shape = indices->shape(); + outer_size_ = 1; + for (size_t i = 0; i < static_cast(axis_); i++) { + outer_size_ *= indices_shape[i]; + } + inner_size_ = indices->ElementsNum() / outer_size_; + + if (context_ == nullptr) { + MS_LOG(ERROR) << "OneHot context nullptr"; + return RET_NULL_PTR; + } + thread_num_ = context_->threadNum; + + const int indices_rank = static_cast(inputs_.at(0)->shape().size()); + if (axis_ < 0) { + axis_ += indices_rank + 1; + } + + return RET_OK; +} + +int RunOneHot(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto onehot_kernel = reinterpret_cast(cdata); + if (onehot_kernel == nullptr) { + MS_LOG(ERROR) << "cast OneHotCPUKernel failed"; + return RET_ERROR; + } + auto error_code = onehot_kernel->OneHotImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "RunOneHot error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int OneHotCPUKernel::OneHotImpl(int task_id) { + auto indices_data = static_cast(inputs_.at(0)->Data()); + auto output = outputs_.at(0); + if (output == nullptr) { + MS_LOG(ERROR) << "OneHot output nullptr"; + return RET_NULL_PTR; + } + auto output_data = static_cast(output->Data()); + + auto ret = GetParams(); + if (ret != RET_OK) { + return ret; + } + auto one_hot_param = reinterpret_cast(opParameter); + + ret = OneHot(indices_data, output_data, one_hot_param, task_id, thread_num_); + return ret; +} + +int OneHotCPUKernel::GetParams() { + auto one_hot_param = reinterpret_cast(opParameter); + if (one_hot_param == nullptr) { + MS_LOG(ERROR) << "cast OneHotParameter nullptr"; + return RET_NULL_PTR; + } + + auto depth_tensor = inputs_.at(1); + if (depth_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr"; + return RET_NULL_PTR; + } + const int *depth = static_cast(depth_tensor->Data()); + if (depth == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->depth_ = *depth; + + auto on_value_tensor = inputs_.at(2); + if (on_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; + return RET_NULL_PTR; + } + const float *on_value = static_cast(on_value_tensor->Data()); + if (on_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->on_value_ = *on_value; + + auto off_value_tensor = inputs_.at(3); + if (off_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; + return RET_NULL_PTR; + } + const float *off_value = static_cast(off_value_tensor->Data()); + if (off_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->off_value_ = *off_value; + + one_hot_param->outer_size_ = outer_size_; + one_hot_param->inner_size_ = inner_size_; + + return RET_OK; +} + +int OneHotCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(RunOneHot, this, context_->threadNum); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "OneHot function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuOneHotFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter != nullptr) { + MS_LOG(ERROR) << "OneHot opParameter nullptr."; + return nullptr; + } + if (desc.type != schema::PrimitiveType_OneHot) { + MS_LOG(ERROR) << "OneHot desc type should be " << schema::PrimitiveType_OneHot << " got " << desc.type; + return nullptr; + } + auto *kernel = new (std::nothrow) OneHotCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "OneHot new kernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_OneHot, CpuOneHotFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h new file mode 100644 index 00000000000..dd823f44600 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" + +namespace mindspore::kernel { +class OneHotCPUKernel : public LiteKernel { + public: + OneHotCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) {} + + ~OneHotCPUKernel() override = default; + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int OneHotImpl(int task_id); + + private: + int GetParams(); + + private: + const lite::Context *context_; + int thread_num_; + int axis_; + int outer_size_; + int inner_size_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ONE_HOT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc new file mode 100644 index 00000000000..206b50c0605 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32/pad.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pad; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +constexpr int kInputRank = 4; +constexpr int kPaddingsSize = 8; +} // namespace + +int PadCPUKernel::CheckInputsOutputsParams() { + if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << inputs_.size() << ", output size should be" + << kOutputNum << ", got " << outputs_.size(); + return RET_ERROR; + } + + auto input = inputs_.at(0); + auto output = outputs_.at(0); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "Pad input or output nullptr"; + return RET_NULL_PTR; + } + + auto rank = input->shape().size(); + if (rank != kInputRank) { + MS_LOG(ERROR) << "Pad input rank should be " << kInputRank << ", got " << rank; + return RET_ERROR; + } + + if (paddings_size_ != kPaddingsSize) { + MS_LOG(ERROR) << "Pad op paddings size should be 2*input_rank: " << 2 * rank << " but got " << paddings_size_; + return RET_ERROR; + } + + for (auto pad : paddings_) { + if (pad < 0) { + MS_LOG(ERROR) << "Pad op paddings should be >= 0, but got " << pad; + return RET_ERROR; + } + } + return RET_OK; +} + +int PadCPUKernel::MaybeConvertInputLayout() { + auto input = inputs_.at(0); + auto input_format = input->GetFormat(); + if (input_format != exec_format_) { + auto input_type = input->data_type(); + layout_convertor_ = LayoutTransform(input_type, input_format, exec_format_); + if (layout_convertor_ == nullptr) { + MS_LOG(ERROR) << "Pad lack layout convertor from " << input_format << " to " << exec_format_; + return RET_ERROR; + } + exec_input_data_ = reinterpret_cast(malloc(input->DataSize() * sizeof(float))); + if (exec_input_data_ == nullptr) { + MS_LOG(ERROR) << "Pad malloc failed."; + return RET_ERROR; + } + } + return RET_OK; +} + +int PadCPUKernel::Init() { + auto ret = CheckInputsOutputsParams(); + if (ret != RET_OK) { + return ret; + } + + ret = MaybeConvertInputLayout(); + if (ret != RET_OK) { + return ret; + } + + auto output = outputs_.at(0); + output->SetFormat(exec_format_); + + return RET_OK; +} + +int PadImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto padKernel = reinterpret_cast(cdata); + int error_code = padKernel->RunImpl(task_id); + if (error_code != OPCLIB_OK) { + MS_LOG(ERROR) << "Pad Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PadCPUKernel::RunImpl(int task_id) { + auto input = inputs_.at(0); + auto output = outputs_.at(0); + + auto input_data = reinterpret_cast(input->Data()); + auto output_data = reinterpret_cast(output->Data()); + auto input_shape = input->shape().data(); + auto output_shape = output->shape().data(); + if (exec_input_data_ != nullptr) { + Pad(exec_input_data_, output_data, input_shape, output_shape, paddings_.data(), task_id, context_->threadNum); + } else { + Pad(input_data, output_data, input_shape, output_shape, paddings_.data(), task_id, context_->threadNum); + } + + return RET_OK; +} + +int PadCPUKernel::Run() { + auto output = outputs_.at(0); + int output_size = output->DataSize(); + + auto output_data = reinterpret_cast(output->Data()); + // todo parallel memset to save time + memset(output_data, 0, output_size * sizeof(float)); + + auto input = inputs_.at(0); + if (exec_input_data_ != nullptr) { + if (layout_convertor_ == nullptr) { + return RET_NULL_PTR; + } + layout_convertor_(inputs_.at(0), exec_input_data_, input->Batch(), input->Height() * input->Width(), + input->Channel()); + } + + int error_code = LiteBackendParallelLaunch(PadImpl, this, context_->threadNum); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Pad opParameter nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Pad); + auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PadCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Pad, CpuPadFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h new file mode 100644 index 00000000000..b9e7226956c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/pad.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +namespace mindspore::kernel { +class PadCPUKernel : public LiteKernel { + public: + PadCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) {} + + ~PadCPUKernel() { + if (exec_input_data_ != nullptr) { + free(exec_input_data_); + exec_input_data_ = nullptr; + } + } + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int RunImpl(int task_id); + + private: + int CheckInputsOutputsParams(); + int MaybeConvertInputLayout(); + + private: + std::vector paddings_; + size_t paddings_size_; + const lite::Context *context_; + schema::Format exec_format_ = schema::Format_NHWC; + LayoutConvertor layout_convertor_ = nullptr; + float *exec_input_data_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc new file mode 100644 index 00000000000..cee6090408e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/pooling.h" +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore::kernel { +int PoolingCPUKernel::Init() { + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::ReSize() { + auto ret = Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Pooling resize init failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::RunImpl(int task_id) { + auto input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (pooling_param_->max_pooling_) { + MaxPooling(input_ptr, output_ptr, pooling_param_, task_id); + } else { + AvgPooling(input_ptr, output_ptr, pooling_param_, task_id); + } + return RET_OK; +} + +int PoolingImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto pooling = reinterpret_cast(cdata); + auto error_code = pooling->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(PoolingImpl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h new file mode 100644 index 00000000000..7edd82a5374 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ + +#include +#include "src/runtime/kernel/arm/base/pooling_base.h" +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" + +namespace mindspore::kernel { +using mindspore::lite::Context; +using mindspore::schema::PadMode; +using mindspore::schema::PoolMode; +using mindspore::schema::QuantType; +using mindspore::schema::RoundMode; + +class PoolingCPUKernel : public PoolingBaseCPUKernel { + public: + PoolingCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~PoolingCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc new file mode 100644 index 00000000000..f5ca09666c8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/power.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Power; + +namespace mindspore::kernel { +int PowerCPUKernel::Init() { return RET_OK; } + +int PowerCPUKernel::ReSize() { return RET_OK; } + +int PowerImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto kernel = reinterpret_cast(cdata); + auto ret = kernel->RunImpl(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PowerImpl error: " << ret; + return ret; + } + return RET_OK; +} + +int PowerCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(PowerImpl, this, thread_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PowerCPUKernel error: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +int PowerCPUKernel::RunImpl(int task_id) { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + auto size = inputs_.at(0)->Size(); + int stride = UP_DIV(size, thread_count_); + int len = MSMIN(stride, size - stride * task_id); + + Power(input_addr + stride * task_id, output_addr + stride * task_id, len, power_, scale_, shift_); + return RET_OK; +} + +kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Power); + auto *kernel = + new (std::nothrow) PowerCPUKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Power, CpuPowerFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.h b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h new file mode 100644 index 00000000000..0867a580dbf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/power.h" + +namespace mindspore::kernel { +class PowerCPUKernel : public LiteKernel { + public: + PowerCPUKernel(PowerParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(reinterpret_cast(param), inputs, outputs), + thread_count_(ctx->threadNum), + power_(param->power_), + scale_(param->scale_), + shift_(param->shift_) {} + ~PowerCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: + int thread_count_; + float power_; + float scale_; + float shift_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc new file mode 100644 index 00000000000..54e01955bf4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/prelu.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/opclib/prelu.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Prelu; + +namespace mindspore::kernel { +int PReluCPUKernel::Init() { + prelu_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int PReluCPUKernel::DoExcute(int task_id) { + PRelu(input_data, output_data, prelu_param_, task_id); + return RET_OK; +} + +int PReluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto PReludata = reinterpret_cast(cdata); + auto ret = PReludata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PReluRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PReluCPUKernel::Run() { + auto input = inputs_.at(0); + prelu_param_->input_num_ = input->ElementsNum(); + input_data = reinterpret_cast(input->Data()); + output_data = reinterpret_cast(outputs_.at(0)->Data()); + + auto ret = LiteBackendParallelLaunch(PReluRun, this, prelu_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PReluDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + + auto *kernel = new (std::nothrow) PReluCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new PReluCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Prelu, CpuPReluFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h new file mode 100644 index 00000000000..b9335b7ef94 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/prelu.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class PReluCPUKernel : public LiteKernel { + public: + PReluCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + prelu_param_ = (reinterpret_cast(opParameter)); + } + ~PReluCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + PReluParameter *prelu_param_; + + private: + float *input_data; + float *output_data; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc new file mode 100644 index 00000000000..b6fe744df9c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/range.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Range; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 0; +constexpr int kOutputNum = 1; +} // namespace + +int RangeCPUKernel::Init() { return RET_OK; } + +int RangeCPUKernel::ReSize() { return RET_OK; } + +int RangeCPUKernel::Run() { + size_t start = (reinterpret_cast(opParameter))->start_; + size_t limit = (reinterpret_cast(opParameter))->limit_; + size_t delta = (reinterpret_cast(opParameter))->delta_; + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + Range(output_ptr, start, limit, delta); + return RET_OK; +} + +kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Range); + + auto *kernel = new (std::nothrow) RangeCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Range, CpuRangeFp32KernelCreator) + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.h b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h new file mode 100644 index 00000000000..cf90593df25 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/range.h" + +namespace mindspore::kernel { +class RangeCPUKernel : public LiteKernel { + public: + explicit RangeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~RangeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc new file mode 100644 index 00000000000..9dbb21e9e9a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/rank.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Rank; + +namespace mindspore::kernel { + +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +} // namespace + +int RankCPUKernel::Init() { return RET_OK; } + +int RankCPUKernel::ReSize() { return RET_OK; } + +int RankCPUKernel::Run() { + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + auto in_shape = inputs_[0]->shape(); + auto rank = in_shape.size(); + Rank(output_ptr, rank); + return RET_OK; +} + +kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Rank); + + auto *kernel = new (std::nothrow) RankCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Rank, CpuRankFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h new file mode 100644 index 00000000000..e167508404f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/rank.h" + +namespace mindspore::kernel { +class RankCPUKernel : public LiteKernel { + public: + explicit RankCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~RankCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc new file mode 100644 index 00000000000..7142e9a9ba7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reduce.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/opclib/fp32/reduce.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reduce; +using mindspore::schema::ReduceMode; +using mindspore::schema::ReduceMode_ReduceMax; +using mindspore::schema::ReduceMode_ReduceMean; +using mindspore::schema::ReduceMode_ReduceMin; +using mindspore::schema::ReduceMode_ReduceProd; +using mindspore::schema::ReduceMode_ReduceSum; +using mindspore::schema::ReduceMode_ReduceSumSquare; + +namespace mindspore::kernel { +namespace { +constexpr size_t kInputNum = 1; +constexpr size_t kOutputNum = 1; +} // namespace + +int ReduceCPUKernel::CheckInputsOutputs() { + if (inputs_.size() != kInputNum) { + MS_LOG(ERROR) << "Reduce inputs size should be " << kInputNum << " but got " << inputs_.size(); + return RET_ERROR; + } + if (outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Reduce outputs size should be " << kOutputNum << " but got " << outputs_.size(); + return RET_ERROR; + } + auto input = inputs_.at(0); + if (input == nullptr) { + MS_LOG(ERROR) << "Reduce input is nullptr"; + return RET_NULL_PTR; + } + auto output = outputs_.at(0); + if (output == nullptr) { + MS_LOG(ERROR) << "Reduce output is nullptr"; + return RET_NULL_PTR; + } + return RET_OK; +} + +int ReduceCPUKernel::CheckParameters() { + size_t input_rank = inputs_.at(0)->shape().size(); + if (static_cast(num_axes_) > input_rank) { + MS_LOG(ERROR) << "Reduce num of reduce axes " << num_axes_ << " larger than input rank " << input_rank; + return RET_ERROR; + } + for (auto i = 0; i < num_axes_; i++) { + if (static_cast(axes_[i]) < -input_rank || static_cast(axes_[i]) >= input_rank) { + MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" << -input_rank << ", " + << input_rank - 1 << "]."; + return RET_ERROR; + } + if (axes_[i] < 0) { + axes_[i] += input_rank; + } + } + + if (num_axes_ == 0) { + for (int i = 0; i < input_rank; i++) { + axes_[i] = i; + } + } + + return RET_OK; +} + +int ReduceCPUKernel::Init() { + auto ret = CheckInputsOutputs(); + if (ret != RET_OK) { + return ret; + } + ret = CheckParameters(); + if (ret != RET_OK) { + return ret; + } + ret = MallocTmpBuffer(); + if (ret != RET_OK) { + return ret; + } + + switch (mode_) { + case static_cast(ReduceMode_ReduceSum): { + reducer_ = ReduceSum; + break; + } + case static_cast(ReduceMode_ReduceMean): { + reducer_ = ReduceMean; + break; + } + case static_cast(ReduceMode_ReduceMax): { + reducer_ = ReduceMax; + break; + } + case static_cast(ReduceMode_ReduceMin): { + reducer_ = ReduceMin; + break; + } + case static_cast(ReduceMode_ReduceProd): { + reducer_ = ReduceProd; + break; + } + case static_cast(ReduceMode_ReduceSumSquare): { + reducer_ = ReduceSumSquare; + break; + } + default: + MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; + return RET_ERROR; + } + return RET_OK; +} + +int ReduceCPUKernel::CallReduceUnit(int task_id) { + auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id, + context_->threadNum); + return ret; +} + +int ReduceImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto reduce = reinterpret_cast(cdata); + auto error_code = reduce->CallReduceUnit(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ReduceCPUKernel::Run() { + tmp_shape_ = inputs_.at(0)->shape(); + src_data_ = static_cast(inputs_.at(0)->Data()); + for (int i = 0; i < tmp_shape_.size(); ++i) { + dst_data_ = data_buffers_[i]; + int axis = axes_[i]; + outer_size_ = 1; + for (int j = 0; j < axis; j++) { + outer_size_ *= tmp_shape_[j]; + } + inner_size_ = 1; + for (int k = axis + 1; k < static_cast(tmp_shape_.size()); k++) { + inner_size_ *= tmp_shape_[k]; + } + axis_size_ = tmp_shape_[axis]; + auto error_code = LiteBackendParallelLaunch(ReduceImpl, this, context_->threadNum); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + tmp_shape_[axis] = 1; + src_data_ = dst_data_; + } + + int last_reduce_axis = axes_[num_axes_ - 1]; + outer_size_ = 1; + for (int i = 0; i < last_reduce_axis; i++) { + outer_size_ *= tmp_shape_[i]; + } + inner_size_ = 1; + for (int i = last_reduce_axis + 1; i < static_cast(tmp_shape_.size()); i++) { + inner_size_ *= tmp_shape_[i]; + } + axis_size_ = tmp_shape_[last_reduce_axis]; + dst_data_ = reinterpret_cast(outputs_.at(0)->Data()); + auto error_code = LiteBackendParallelLaunch(ReduceImpl, this, context_->threadNum); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Reduce); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Reduce opParameter nullptr"; + return nullptr; + } + if (desc.type != schema::PrimitiveType_Reduce) { + MS_LOG(ERROR) << "Reduce op desc.type should be PrimitiveType_Reduce, got " << desc.type; + return nullptr; + } + auto *kernel = + new (std::nothrow) ReduceCPUKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + } + return kernel; +} + +int ReduceCPUKernel::MallocTmpBuffer() { + auto input_shape = inputs_.at(0)->shape(); + for (auto i = 0; i < num_axes_ - 1; i++) { + int axis = axes_[i]; + size_t size = 1; + for (auto j = 0; j < input_shape.size(); j++) { + if (static_cast(axis) != j) { + size *= input_shape[j]; + } + } + float *buffer = reinterpret_cast(malloc(size * sizeof(float))); + if (buffer == nullptr) { + MS_LOG(ERROR) << "Malloc data failed."; + return RET_ERROR; + } + data_buffers_.emplace_back(buffer); + input_shape[axis] = 1; + } +} + +REG_KERNEL(kCPU, PrimitiveType_Reduce, CpuReduceFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h new file mode 100644 index 00000000000..41a12203539 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/fp32/reduce.h" +#include "ir/anf.h" +using mindspore::schema::ReduceMode; + +namespace mindspore::kernel { +class ReduceCPUKernel : public LiteKernel { + typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); + + public: + ReduceCPUKernel(ReduceParameter *param, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(reinterpret_cast(param), inputs, outputs), + context_(ctx), + keep_dims_(param->keep_dims_), + num_axes_(param->num_axes_), + mode_(param->mode_) { + memcpy(axes_, param->axes_, sizeof(param->axes_)); + } + ~ReduceCPUKernel() { + for (auto i = 0; i < data_buffers_.size(); i++) { + float *buffer = data_buffers_[i]; + if (buffer != nullptr) { + free(buffer); + buffer = nullptr; + } + } + src_data_ = nullptr; + dst_data_ = nullptr; + } + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int CallReduceUnit(int task_id); + + private: + int CheckInputsOutputs(); + int CheckParameters(); + int MallocTmpBuffer(); + + private: + const lite::Context *context_ = nullptr; + bool keep_dims_; + int axes_[REDUCE_MAX_AXES_NUM]; + int num_axes_; + int mode_; + + private: + std::vector data_buffers_; + int outer_size_; + int inner_size_; + int axis_size_; + std::vector tmp_shape_; + const float *src_data_; + float *dst_data_; + Reducer reducer_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REDUCE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc new file mode 100644 index 00000000000..7cd932c33fd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reshape.h" +#include +#include "src/runtime/kernel/arm/opclib/reshape.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reshape; + +namespace mindspore::kernel { +int ReshapeCPUKernel::Init() { + ReshapeBaseCPUKernel::Init(); + return RET_OK; +} + +int ReshapeCPUKernel::ReSize() { return RET_OK; } + +int ReshapeCPUKernel::Run() { + auto input_ptr = inputs_.at(kInputIndex)->Data(); + auto output_ptr = outputs_.at(kOutputIndex)->Data(); + size_t data_size = inputs_.at(kInputIndex)->Size(); + Reshape(input_ptr, output_ptr, data_size); + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h new file mode 100644 index 00000000000..f366e739d97 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/reshape_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeCPUKernel : public ReshapeBaseCPUKernel { + public: + ReshapeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ReshapeCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc new file mode 100644 index 00000000000..7b41b3e28ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/resize.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/resize.h" +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_INVALID_OP_ATTR; +using mindspore::lite::RET_NULL_PTR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +namespace { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; +constexpr int kRank = 4; +} // namespace + +int ResizeCPUKernel::CheckParameters() { + auto parameter = reinterpret_cast(opParameter); + if (parameter == nullptr) { + MS_LOG(ERROR) << "cast ResizeParameter failed."; + return RET_NULL_PTR; + } + method_ = parameter->method_; + if (method_ != schema::ResizeMethod_BILINEAR && method_ != schema::ResizeMethod_NEAREST_NEIGHBOR) { + MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_; + return RET_INVALID_OP_ATTR; + } + new_height_ = parameter->new_height_; + if (new_height_ < 1) { + MS_LOG(ERROR) << "Resize new_height should >= 1, but got " << new_height_; + return RET_INVALID_OP_ATTR; + } + new_width_ = parameter->new_width_; + if (new_width_ < 1) { + MS_LOG(ERROR) << "Resize new_width should >= 1, but got " << new_width_; + return RET_INVALID_OP_ATTR; + } + align_corners_ = parameter->align_corners_; + preserve_aspect_ratio = parameter->preserve_aspect_ratio_; + if (preserve_aspect_ratio) { + MS_LOG(ERROR) << "Resize currently not support preserve_aspect_ratio true"; + return RET_ERROR; + } + return RET_OK; +} + +int ResizeCPUKernel::CheckInputsOuputs() { + if (inputs_.size() != kInputNum) { + MS_LOG(ERROR) << "Resize input num should be " << kInputNum << ", but got " << inputs_.size(); + return RET_ERROR; + } + auto input = inputs_.at(0); + if (input == nullptr) { + return RET_NULL_PTR; + } + if (outputs_.size() != kOutputNum) { + MS_LOG(ERROR) << "Resize output num should be " << kOutputNum << ", but got " << outputs_.size(); + return RET_ERROR; + } + auto output = outputs_.at(0); + if (output == nullptr) { + return RET_NULL_PTR; + } + return RET_OK; +} + +int ResizeCPUKernel::Init() { + auto ret = CheckParameters(); + if (ret != RET_OK) { + return ret; + } + ret = CheckInputsOuputs(); + if (ret != RET_OK) { + return ret; + } + + auto output = outputs_.at(0); + auto input = inputs_.at(0); + auto input_shape = input->shape(); + if (input_shape.size() != kRank) { + return RET_ERROR; + } + schema::Format execute_format; + size_t exec_input_size; + switch (method_) { + case schema::ResizeMethod_BILINEAR: { + execute_format = schema::Format_NC4HW4; + output->SetFormat(schema::Format_NC4HW4); + exec_input_size = input->ElementsC4Num(); + break; + } + case schema::ResizeMethod_NEAREST_NEIGHBOR: { + execute_format = schema::Format_NHWC; + output->SetFormat(schema::Format_NHWC); + exec_input_size = input->ElementsNum(); + break; + } + default: { + MS_LOG(ERROR) << "Resize unknown method " << method_; + return RET_ERROR; + } + } + + auto input_format = input->GetFormat(); + if (input_format != execute_format) { + auto input_type = input->data_type(); + layout_convertor_ = LayoutTransform(input_type, input_format, execute_format); + exec_input_data_ = reinterpret_cast(malloc(exec_input_size * sizeof(float))); + if (exec_input_data_ == nullptr) { + return RET_NULL_PTR; + } + } + + return RET_OK; +} + +int ResizeImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto resize = reinterpret_cast(cdata); + auto error_code = resize->RunImpl(task_id); + if (error_code != OPCLIB_OK) { + MS_LOG(ERROR) << "Resize Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ResizeCPUKernel::RunImpl(int task_id) { + auto input = inputs_.at(0); + auto input_data = reinterpret_cast(input->Data()); + if (input_data == nullptr) { + return RET_NULL_PTR; + } + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + if (output_data == nullptr) { + return RET_NULL_PTR; + } + auto input_shape = input->shape(); + if (input_shape.size() != kRank) { + return RET_ERROR; + } + if (context_ == nullptr) { + return RET_NULL_PTR; + } + + int ret = 0; + switch (method_) { + case schema::ResizeMethod_BILINEAR: { + if (layout_convertor_ != nullptr) { + layout_convertor_(input_data, exec_input_data_, input->Batch(), input->Height() * input->Width(), + input->Channel()); + ret = ResizeBilinear(exec_input_data_, output_data, inputs_[0]->shape().data(), outputs_[0]->shape().data(), + align_corners_, task_id, context_->threadNum); + } else { + ret = ResizeBilinear(input_data, output_data, inputs_[0]->shape().data(), outputs_[0]->shape().data(), + align_corners_, task_id, context_->threadNum); + } + break; + } + case schema::ResizeMethod_NEAREST_NEIGHBOR: { + if (align_corners_) { + MS_LOG(ERROR) << "ResizeNearestNeighbor not support align_corners."; + return RET_ERROR; + } + if (layout_convertor_ != nullptr) { + layout_convertor_(input_data, exec_input_data_, input->Batch(), input->Height() * input->Width(), + input->Channel()); + ret = ResizeNearestNeighbor(exec_input_data_, output_data, input_shape.data(), outputs_[0]->shape().data(), + task_id, context_->threadNum); + } else { + ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), outputs_[0]->shape().data(), task_id, + context_->threadNum); + } + break; + } + case schema::ResizeMethod_UNKNOW: + default: { + MS_LOG(ERROR) << "Resize unknown method " << method_; + ret = OPCLIB_ERR; + } + } + return ret; +} + +int ResizeCPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(ResizeImpl, this, context_->threadNum); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ResizeCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ResizeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Resize, CpuResizeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h new file mode 100644 index 00000000000..2624d136cf9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/resize.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::schema::PrimitiveType_Resize; +using mindspore::schema::ResizeMethod; + +namespace mindspore::kernel { +class ResizeCPUKernel : public LiteKernel { + public: + ResizeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), context_(ctx) {} + + ~ResizeCPUKernel() { + if (exec_input_data_ != nullptr) { + free(exec_input_data_); + exec_input_data_ = nullptr; + } + } + + int Init() override; + int ReSize() override { return 0; }; + int Run() override; + int RunImpl(int task_id); + + protected: + const lite::Context *context_; + + private: + int CheckParameters(); + int CheckInputsOuputs(); + + private: + ResizeMethod method_; + int64_t new_height_; + int64_t new_width_; + bool align_corners_; + bool preserve_aspect_ratio; + LayoutConvertor layout_convertor_ = nullptr; + float *exec_input_data_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc new file mode 100644 index 00000000000..515e41d2c13 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/reverse.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/reverse.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reverse; + +namespace mindspore::kernel { + +int ReverseCPUKernel::Stride(int index) { + int i, stride = 1; + for (i = index + 1; i < inputs_[0]->shape().size(); ++i) { + stride *= inputs_[0]->shape()[i]; + } + return stride; +} + +int ReverseCPUKernel::ReSize() { + auto *param = reinterpret_cast(opParameter); + auto input_shape = inputs_[0]->shape(); + if (param->num_axis_ > input_shape.size()) { + MS_LOG(ERROR) << "Reverse dims : " << param->num_axis_ + << "is greater than input shape size :" << input_shape.size(); + return RET_ERROR; + } + if (input_shape.size() > REVERSE_SHAPE_MAX_SIZE) { + MS_LOG(ERROR) << "input dimension num should <= " << REVERSE_SHAPE_MAX_SIZE; + return RET_ERROR; + } + + if (tmp_ != nullptr) { + free(tmp_); + tmp_ = nullptr; + } + tmp_ = reinterpret_cast(malloc(data_size_ * sizeof(int))); + if (tmp_ == nullptr) { + MS_LOG(ERROR) << "Reverse Malloc tmp_ error!"; + return RET_ERROR; + } + (void)memset(tmp_, 0, data_size_ * sizeof(int)); + + for (int i = 0; i < param->num_axis_; i++) { + int axis = param->axis_[i]; + int stride = Stride(axis); + strides_[i] = stride; + inCount_[i] = input_shape[axis]; + outCount_[i] = 1; + for (int j = 0; j < axis; j++) { + outCount_[i] *= input_shape[j]; + } + } + + int out, in, C, m; + for (int i = 0; i < data_size_; ++i) { + int tmp = i; + for (int j = 0; j < param->num_axis_; ++j) { + C = inCount_[j]; + out = tmp / (C * strides_[j]); + in = tmp / strides_[j] - out * C; + m = tmp % strides_[j]; + tmp = out * C * strides_[j] + strides_[j] * (C - 1 - in) + m; + } + tmp_[i] = tmp; + } + + return RET_OK; +} + +int ReverseCPUKernel::Init() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + int ret = ReSize(); + return ret; +} + +int ReverseRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoReverse(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "reverseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ReverseCPUKernel::DoReverse(int task_id) { + int count = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (count <= 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + auto ret = Reverse(in_ptr_ + offset, out_ptr_, thread_sz_stride_, tmp_ + offset); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReverseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int ReverseCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_[0]->Data()); + out_ptr_ = reinterpret_cast(outputs_[0]->Data()); + int ret = LiteBackendParallelLaunch(ReverseRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Reverse run error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter is NULL! "; + return nullptr; + } + auto *kernel = new (std::nothrow) ReverseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Kernel is NULL! name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Reverse, CpuReverseFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h new file mode 100644 index 00000000000..a64a31cc668 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" + +#define REVERSE_STRIDE_MAX_SIZE 4 + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReverseCPUKernel : public LiteKernel { + public: + ReverseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {} + ~ReverseCPUKernel() { + if (tmp_ != nullptr) { + free(tmp_); + tmp_ = nullptr; + } + } + + int Init() override; + int ReSize() override; + int Run() override; + int Stride(int index); + int DoReverse(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + int strides_[REVERSE_STRIDE_MAX_SIZE]; + int inCount_[REVERSE_STRIDE_MAX_SIZE]; + int outCount_[REVERSE_STRIDE_MAX_SIZE]; + const Context *ctx_; + int *tmp_ = nullptr; + float *in_ptr_; + float *out_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc new file mode 100644 index 00000000000..9677d3b337c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/reverse_sequence.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ReverseSequence; + +namespace mindspore::kernel { +int ReverseSequenceCPUKernel::Init() { + auto input0 = inputs_.at(0); + auto input1 = inputs_.at(1); + auto output = outputs_.at(0); + MS_ASSERT(input0 != nullptr); + MS_ASSERT(input1 != nullptr); + MS_ASSERT(output != nullptr); + + auto para = reinterpret_cast(opParameter); + + ConvertAxisToPositive(input0->shape(), &(para->batch_axis_)); + ConvertAxisToPositive(input0->shape(), &(para->seq_axis_)); + + para->ndim_ = input0->shape().size(); + for (int i = 0; i < para->ndim_; i++) { + para->input_shape0_[i] = input0->DimensionSize(i); + para->output_shape_[i] = output->DimensionSize(i); + } + + int less_axis = MSMIN(para->batch_axis_, para->seq_axis_); + int greater_axis = MSMAX(para->batch_axis_, para->seq_axis_); + + para->outer_count_ = CalcCountPreAxis(input0->shape(), less_axis); + para->outer_stride_ = input0->DimensionSize(less_axis) * CalcCountAfterAxis(input0->shape(), less_axis); + + para->inner_count_ = 1; + for (int i = less_axis + 1; i < greater_axis; ++i) { + para->inner_count_ *= input0->DimensionSize(i); + } + + para->inner_stride_ = input0->DimensionSize(greater_axis) * CalcCountAfterAxis(input0->shape(), greater_axis); + + para->copy_byte_size_ = sizeof(float) * CalcCountAfterAxis(input0->shape(), greater_axis); + para->total_data_size_ = input0->Size(); + return RET_OK; +} + +void ReverseSequenceCPUKernel::ConvertAxisToPositive(const std::vector shape, int *axis) { + if (axis != nullptr && *axis < 0) { + *axis += shape.size(); + } +} + +int ReverseSequenceCPUKernel::CalcCountPreAxis(const std::vector shape, int axis) { + int count = 1; + for (int i = 0; i < axis; ++i) { + count *= shape[i]; + } + return count; +} +int ReverseSequenceCPUKernel::CalcCountAfterAxis(const std::vector shape, int axis) { + int count = 1; + for (int i = axis + 1; i < shape.size(); ++i) { + count *= shape[i]; + } + return count; +} + +int ReverseSequenceCPUKernel::ReSize() { return RET_OK; } + +int ReverseSequenceCPUKernel::Run() { + float *input0 = reinterpret_cast(inputs_.at(0)->Data()); + int *input1 = reinterpret_cast(inputs_.at(1)->Data()); + float *output = reinterpret_cast(outputs_.at(0)->Data()); + ReverseSequence(input0, input1, output, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + auto *kernel = new (std::nothrow) ReverseSequenceCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_ReverseSequence, CpuReverseSequenceFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h new file mode 100644 index 00000000000..e98eb020190 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/reverse_sequence.h" + +namespace mindspore::kernel { +class ReverseSequenceCPUKernel : public LiteKernel { + public: + ReverseSequenceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ReverseSequenceCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + void ConvertAxisToPositive(const std::vector shape, int *axis); + int CalcCountPreAxis(const std::vector shape, int axis); + int CalcCountAfterAxis(const std::vector shape, int axis); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc new file mode 100644 index 00000000000..5432e0bdbc8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/scale.h" +#include +#include +#include "src/runtime/kernel/arm/opclib/scale.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Scale; + +namespace mindspore::kernel { +namespace { +constexpr int kScaleInputNum = 1; +constexpr int kScaleOutputNum = 1; +} // namespace +int ScaleCPUKernel::Init() { + auto param = reinterpret_cast(opParameter); + auto in_tensor = inputs_.front(); + auto scale = inputs_.at(1); + + if (inputs_.size() < 2 || inputs_.size() > 3) { + MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + + if (param->axis_ < 0) { + MS_LOG(ERROR) << "axis illegal."; + return RET_ERROR; + } + if (param->num_axis_ < 1 || param->num_axis_ + param->axis_ >= in_tensor->shape().size()) { + MS_LOG(ERROR) << "number of axis illegal"; + return RET_ERROR; + } + + param->channel_ = 1; + param->out_count_ = 1; + param->in_stride_ = 1; + int cur_axis; + for (cur_axis = 0; cur_axis < param->axis_; cur_axis++) { + param->out_count_ *= in_tensor->shape()[cur_axis]; + } + for (int i = 0; i < param->num_axis_; i++) { + param->channel_ *= in_tensor->shape()[(cur_axis++)]; + } + for (int i = cur_axis; i < in_tensor->shape().size(); i++) { + param->in_stride_ *= in_tensor->shape()[cur_axis]; + } + if (scale->shape().back() != param->channel_ || scale->shape().size() > 2) { + MS_LOG(ERROR) << "scale shape illegal."; + return RET_ERROR; + } + if (inputs_.size() == 3) { + if ((inputs_.at(2))->shape().back() != param->channel_ || (inputs_.at(2))->shape().size() > 2) { + MS_LOG(ERROR) << "offset shape illegal."; + return RET_ERROR; + } + } + + input_ptr_ = reinterpret_cast(inputs_.front()->Data()); + scale_ = reinterpret_cast(inputs_.at(1)->Data()); + if (inputs_.size() == 3) { + offset_ = reinterpret_cast(inputs_.at(2)->Data()); + has_offset_ = true; + } else { + offset_ = nullptr; + has_offset_ = false; + } + output_ptr_ = reinterpret_cast(outputs_.front()->Data()); + + num_unit_ = param->out_count_ * param->channel_; + unit_size_ = param->in_stride_; + thread_n_num_ = MSMIN(thread_num_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + return RET_OK; +} + +int ScaleCPUKernel::Scale(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_n_stride_; + int ret; + if (has_offset_) { + ret = DoScale(input_ptr_, output_ptr_, scale_, offset_, thread_offset, num_unit_thread, + reinterpret_cast(opParameter)); + } else { + ret = DoScale(input_ptr_, output_ptr_, scale_, thread_offset, num_unit_thread, + reinterpret_cast(opParameter)); + } + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScaleCPUKernel::ReSize() { return RET_OK; } + +int ScaleRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->Scale(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScaleRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScaleCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(ScaleRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Scale); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter is nullptr"; + return nullptr; + } + auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Scale, CpuScaleFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h new file mode 100644 index 00000000000..052ae0a115e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +class ScaleCPUKernel : public LiteKernel { + public: + explicit ScaleCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {} + ~ScaleCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Scale(int task_id); + + private: + int thread_num_; + int thread_n_stride_; + int thread_n_num_; + int num_unit_; + int unit_size_; + float *input_ptr_; + float *scale_; + float *offset_; + float *output_ptr_; + bool has_offset_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCALE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc new file mode 100644 index 00000000000..a012e2c5ab6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/scatter_nd.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ScatterND; + +namespace mindspore::kernel { +namespace { + constexpr int kScatterNDInputNum = 3; + constexpr int kScatterNDOutputNum = 1; + constexpr int kScatterShapeIndex = 0; + constexpr int kScatterIndicesIndex = 1; + constexpr int kScatterUpdateIndex = 2; +} // namespace +int ScatterNDCPUKernel::Init() { + auto shape = inputs_.at(kScatterShapeIndex); + auto indices = inputs_.at(kScatterIndicesIndex); + auto update = inputs_.at(kScatterUpdateIndex); + + update_ptr_ = reinterpret_cast(update->Data()); + output_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + + // check indices shape + auto shape_rank = shape->ElementsNum(); + auto shape_data = reinterpret_cast(shape->Data()); + auto indice_unit_rank = indices->shape().back(); + if (indice_unit_rank > shape_rank) { + MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank."; + return RET_ERROR; + } + + if (indices->shape().size() < 2) { + MS_LOG(ERROR) << "Indices dimension smaller than 2."; + return RET_ERROR; + } + + // check consistency of the shape indices and shape + auto update_rank = static_cast(update->shape().size()); + auto indices_shape = indices->shape(); + if (update_rank != indices->shape().size() - 1 + shape_rank - indice_unit_rank) { + MS_LOG(ERROR) << "Update, shape rank and indices rank inconsistent."; + return RET_ERROR; + } + // check update shape + auto update_shape = update->shape(); + for (size_t i = 0; i < indices_shape.size() - 1; i++) { + if (update_shape[i] != indices_shape[i]) { + MS_LOG(ERROR) << "Value of " << i << " th dimension of indices is not equal to that of update."; + return RET_ERROR; + } + } + for (size_t i = 0; i < shape->ElementsNum() - (indices_shape.size() - 1); i++) { + if (update_shape[i + indices_shape.size() - 1] != shape_data[i + indices_shape.size() - 1]) { + MS_LOG(ERROR) << "Value of " << i + indices_shape.size() - 1 + << " th dimension of indices is not equal to the corresbonding dimension of shape."; + return RET_ERROR; + } + } + // todo check indeices out of range + // for (size_t i = 0; i < static_cast(indice_unit_rank); i++) {} + + // calculate unit_size_ + unit_size_ = 1; + for (int i = indices_shape.size() - 1; i < update_rank; i++) { + unit_size_ *= update_shape[i]; + } + + // calculate offsets + int out_stride = 1; + out_strides_.push_back(1); + for (int i = indice_unit_rank - 2; i >= 0; i--) { + out_stride *= shape_data[i + 1]; + out_strides_.push_back(out_stride); + } + + num_unit_ = 1; + num_unit_ *= update_shape[indices_shape.size() - 2]; + for (int i = indices_shape.size() - 3; i >= 0; i--) { + num_unit_ *= update_shape[i]; + } + + int *indices_ptr = reinterpret_cast(indices->Data()); + for (int i = 0; i < num_unit_; i++) { + int tmp_stride = 0; + for (int j = 0; j < indice_unit_rank; j++) { + tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_[j] * unit_size_; + } + output_unit_offsets_.push_back(tmp_stride); + } + + thread_n_num_ = MSMIN(thread_num_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + return RET_OK; +} + +int ScatterNDCPUKernel::ReSize() { return 0; } + +int ScatterNDCPUKernel::ScatterND(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int offset = task_id * thread_n_stride_; + MS_LOG(ERROR) << "offset " << offset << std::endl; + auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset, + unit_size_, num_unit_thread); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScatterNDRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->ScatterND(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ScatterNDCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(ScatterNDRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuScatterNDFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_ScatterND); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not scatterND"; + return nullptr; + } + auto *kernel = new (std::nothrow) ScatterNDCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != 0) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_ScatterND, CpuScatterNDFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h new file mode 100644 index 00000000000..095d459b02c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/scatter_nd.h" + +namespace mindspore::kernel { + +class ScatterNDCPUKernel : public LiteKernel { + public: + explicit ScatterNDCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {} + ~ScatterNDCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int ScatterND(int task_id); + + private: + int thread_num_; + int thread_n_num_; + int thread_n_stride_; + int num_unit_; + int unit_size_; + float *output_ptr_; + float *update_ptr_; + std::vector out_strides_; + std::vector output_unit_offsets_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc new file mode 100644 index 00000000000..70bed3f05ec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/shape.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Shape; + +namespace mindspore::kernel { +namespace { + constexpr int kShapeInputNum = 1; + constexpr int kShapeOutputNum = 1; +} // namespace +int ShapeCPUKernel::Init() { return RET_OK; } + +int ShapeCPUKernel::ReSize() { return RET_OK; } + +int ShapeCPUKernel::Run() { + auto out_tensor = outputs_.front(); + auto in_tensor = inputs_.front(); + if (in_tensor == nullptr || out_tensor == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_ERROR; + } + if (in_tensor->Data() == nullptr || out_tensor->Data() == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_ERROR; + } + + for (int i = 0; i < in_tensor->shape().size(); i++) { + reinterpret_cast(out_tensor->Data())[i] = in_tensor->shape()[i]; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Shape); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Shape"; + return nullptr; + } + auto *kernel = new (std::nothrow) ShapeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != 0) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Shape, CpuShapeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h new file mode 100644 index 00000000000..e3c6ce333fa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/shape.h" + +namespace mindspore::kernel { + +class ShapeCPUKernel : public LiteKernel { + public: + explicit ShapeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ShapeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc new file mode 100644 index 00000000000..805bf57f1b2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/slice.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/slice.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Slice; + +namespace mindspore::kernel { + +int SliceCPUKernel::Init() { + auto *param = reinterpret_cast(opParameter); + auto input_shape = inputs_[0]->shape(); + if (input_shape.size() != param->param_length_) { + MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size " + << input_shape.size(); + return RET_ERROR; + } + if (input_shape.size() > SLICE_SHAPE_MAX_SIZE) { + MS_LOG(ERROR) << "input dimension num should <= " << SLICE_SHAPE_MAX_SIZE; + return RET_ERROR; + } + + for (size_t i = 0; i < input_shape.size(); ++i) { + param->shape_[i] = input_shape[i]; + } + return RET_OK; +} + +int SliceCPUKernel::Run() { + SliceParameter *param = reinterpret_cast(opParameter); + const float *input_data = reinterpret_cast(inputs_[0]->Data()); + float *output_data = reinterpret_cast(outputs_[0]->Data()); + + return DoSlice(input_data, param, output_data); +} + +kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) SliceCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SliceCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Slice, CpuSliceFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h new file mode 100644 index 00000000000..2591bf15c4c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class SliceCPUKernel : public LiteKernel { + public: + SliceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + ~SliceCPUKernel() = default; + + int Init() override; + int ReSize() override { + return 0; + } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc new file mode 100644 index 00000000000..9522b2227a4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/softmax.h" +#include +#include +#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore::kernel { +int SoftmaxCPUKernel::Init() { + auto input_tensor = inputs_.front(); + auto in_shape = input_tensor->shape(); + auto in_dims = in_shape.size(); + int ele_size = 1; + (reinterpret_cast(opParameter))->n_dim_ = in_dims; + for (size_t i = 0; i < in_dims; i++) { + (reinterpret_cast(opParameter))->input_shape_[i] = in_shape[i]; + ele_size *= in_shape[i]; + } + (reinterpret_cast(opParameter))->element_size_ = ele_size; + + // malloc tmp buffer + auto axis = reinterpret_cast(opParameter)->axis_; + sum_data = reinterpret_cast(malloc(in_shape[axis] * sizeof(float))); + memset(sum_data, 0, in_shape[axis] * sizeof(float)); + return RET_OK; +} + +int SoftmaxCPUKernel::ReSize() { return RET_OK; } + +int SoftmaxCPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Softmax(input_ptr, output_ptr, sum_data, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuSoftmaxFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); + auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_SoftMax, CpuSoftmaxFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h new file mode 100644 index 00000000000..61c1951af33 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class SoftmaxCPUKernel : public LiteKernel { + public: + SoftmaxCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~SoftmaxCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float *sum_data; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SOFTMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc new file mode 100644 index 00000000000..4612649baf7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/sparse_to_dense.h" +#include +#include "schema/model_generated.h" +#include "schema/ops_generated.h" +#include "src/runtime/kernel/arm/opclib/sparse_to_dense.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SparseToDense; + +namespace mindspore::kernel { +int SparseToDenseCPUKernel::Init() { + s2d_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int SparseToDenseCPUKernel::DoExcute(int task_id) { + SparseToDense(input_data_, output_shape_, snum_, dnum_, sp_num_, output_data, s2d_param_, task_id); + return RET_OK; +} + +int SparseToDenseRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto s2ddata = reinterpret_cast(cdata); + auto ret = s2ddata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SparseToDenseRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} +int SparseToDenseCPUKernel::Run() { + auto input = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + auto input3 = inputs_.at(3); + auto output0 = outputs_.at(0); + + input_data_ = reinterpret_cast(input->Data()); + total_number_ = reinterpret_cast(input1->Data()); + snum_ = reinterpret_cast(input2->Data()); + dnum_ = reinterpret_cast(input3->Data()); + sp_num_ = static_cast(input->ElementsNum() / 2); + + output_data = reinterpret_cast(outputs_.at(0)->Data()); + std::vector temp_shape = output0->shape(); + output_shape_ = reinterpret_cast(temp_shape.data()); + + auto ret = LiteBackendParallelLaunch(SparseToDenseRun, this, s2d_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + + auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_SparseToDense, CpuSparseToDenseFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h new file mode 100644 index 00000000000..c13bc00f640 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/sparse_to_dense.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class SparseToDenseCPUKernel : public LiteKernel { + public: + SparseToDenseCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + s2d_param_ = (reinterpret_cast(opParameter)); + } + ~SparseToDenseCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + SparseToDenseParameter *s2d_param_; + + private: + int *input_data_; + int *total_number_; + int sp_num_; + float *snum_; + float *dnum_; + float *output_data; + int *output_shape_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc new file mode 100644 index 00000000000..c149673402e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/kernel/arm/fp32/split.h" +#include "src/runtime/kernel/arm/opclib/split.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Split; + +namespace mindspore::kernel { + +int SplitCPUKernel::Init() { + auto in_tensor = inputs_.front(); + input_ptr_ = reinterpret_cast(in_tensor->Data()); + auto input_shape = in_tensor->shape(); + auto param = reinterpret_cast(opParameter); + + param->strides_[input_shape.size() - 1] = 1; + for (int i = input_shape.size() - 2; i >= 0; i--) { + param->strides_[i] = param->strides_[i + 1] * input_shape[i + 1]; + } + + param->split_count_ = + param->strides_[0] * input_shape[0] / (input_shape[param->split_dim_] * param->strides_[param->split_dim_]); + for (int i = 0; i < param->num_split_; i++) { + output_ptr_.push_back(reinterpret_cast(outputs_.at(i)->Data())); + } + param->n_dims_ = input_shape.size(); + + if (param->split_sizes_[0] == 0) { + if (input_shape[param->split_dim_] % param->num_split_ != 0) { + MS_LOG(ERROR) << "Default split size is not usable."; + return RET_ERROR; + } + int split_size = input_shape[param->split_dim_] / param->num_split_; + for (int i = 0; i < param->num_split_; i++) { + param->split_sizes_[i] = split_size; + } + } + + num_unit_ = param->split_count_ * param->num_split_; + unit_size_ = param->strides_[param->split_dim_]; + thread_n_num_ = MSMIN(thread_num_, num_unit_); + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + return RET_OK; +} + +int SplitCPUKernel::ReSize() { return RET_OK; } + +int SplitCPUKernel::Split(int task_id) { + int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_n_stride_; + auto ret = DoSplit(input_ptr_, output_ptr_.data(), inputs_.front()->shape().data(), thread_offset, num_unit_thread, + reinterpret_cast(opParameter)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Split error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->Split(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SplitRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int SplitCPUKernel::Run() { + int ret = LiteBackendParallelLaunch(SplitRun, this, thread_n_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +kernel::LiteKernel *CpuSplitFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Split); + auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Split, CpuSplitFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.h b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h new file mode 100644 index 00000000000..7129f4fb779 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ + +#include + +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class SplitCPUKernel : public LiteKernel { + public: + SplitCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {} + ~SplitCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Split(int task_id); + + private: + int thread_num_; + int thread_n_stride_; + int thread_n_num_; + int num_unit_; + int unit_size_; + float *input_ptr_; + std::vector output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLIT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc new file mode 100644 index 00000000000..09b24b22720 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/squeeze.h" +#include +#include "src/runtime/kernel/arm/opclib/squeeze.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Squeeze; + +namespace mindspore::kernel { +namespace { + constexpr int kSqueezeInputNum = 1; + constexpr int kSqueezeOutputNum = 1; +} // namespace + +int SqueezeCPUKernel::Init() { return RET_OK; } + +int SqueezeCPUKernel::ReSize() { return RET_OK; } + +int SqueezeCPUKernel::Run() { + auto input_ptr = reinterpret_cast(inputs_.front()->Data()); + auto output_ptr = reinterpret_cast(outputs_.front()->Data()); + size_t data_size = inputs_.front()->Size(); + auto ret = DoSqueeze(input_ptr, output_ptr, data_size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Do squeeze failed."; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Squeeze"; + return nullptr; + } + auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h new file mode 100644 index 00000000000..e48fce51c95 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { + +class SqueezeCPUKernel : public LiteKernel { + public: + explicit SqueezeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~SqueezeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + std::vector axes_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc new file mode 100644 index 00000000000..1e87cde2b3d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/stack.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/opclib/fp32/stack.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Stack; + +namespace mindspore::kernel { +int StackCPUKernel::Init() { + StackParameter *param = reinterpret_cast(opParameter); + auto input0_shape = inputs_[0]->shape(); + axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() : param->axis_; + schema::Format input0_format = inputs_[0]->GetFormat(); + bool need_convert_format = false; + for (size_t i = 1; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != input0_format) { + need_convert_format = true; + } + } + if (!need_convert_format) { + outputs_[0]->SetFormat(input0_format); + return RET_OK; + } + + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i]->GetFormat() != schema::Format_NHWC) { + convert_functions_[i] = LayoutTransform(inputs_[i]->data_type(), inputs_[i]->GetFormat(), schema::Format_NHWC); + if (convert_functions_[i] == nullptr) { + MS_LOG(ERROR) << "Can not convert format " << inputs_[i]->GetFormat() << " to " << schema::Format_NHWC; + return RET_ERROR; + } + size_t packed_input_size = + inputs_[i]->Channel() * inputs_[i]->Batch() * inputs_[i]->Height() * inputs_[i]->Width(); + packed_inputs_[i] = reinterpret_cast(malloc(packed_input_size * sizeof(float))); + if (packed_inputs_[i] == nullptr) { + MS_LOG(ERROR) << "malloc memory fail!"; + return RET_ERROR; + } + memset(packed_inputs_[i], 0, packed_input_size * sizeof(float)); + } else { + convert_functions_[i] = nullptr; + packed_inputs_[i] = nullptr; + } + } + outputs_[0]->SetFormat(schema::Format_NHWC); + return RET_OK; +} + +int StackCPUKernel::Run() { + size_t inputs_num = inputs_.size(); + auto input0_shape = inputs_[0]->shape(); + auto *output_data = reinterpret_cast(outputs_[0]->Data()); + float *inputs[inputs_num]; + for (size_t i = 0; i < inputs_num; ++i) { + inputs[i] = reinterpret_cast(inputs_[i]->Data()); + if (convert_functions_[i] != nullptr) { + convert_functions_[i](inputs[i], packed_inputs_[i], inputs_[i]->Batch(), + inputs_[i]->Height() * inputs_[i]->Width(), inputs_[i]->Channel()); + } else { + packed_inputs_[i] = inputs[i]; + } + } + DoStack(packed_inputs_.data(), inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); + return RET_OK; +} + +kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "Input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) StackCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new StackCPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Stack, CpuStackFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h new file mode 100644 index 00000000000..c1d76ca1937 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/base/layout_transform.h" + +namespace mindspore::kernel { +class StackCPUKernel : public LiteKernel { + public: + StackCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), + convert_functions_(inputs_.size(), nullptr), + packed_inputs_(inputs_.size(), nullptr) {} + + ~StackCPUKernel() { + for (size_t i = 0; i < packed_inputs_.size(); ++i) { + if (packed_inputs_[i] != nullptr) { + free(packed_inputs_[i]); + packed_inputs_[i] = nullptr; + } + } + } + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + + private: + int axis_; + std::vector convert_functions_; + std::vector packed_inputs_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.cc new file mode 100644 index 00000000000..112fe294742 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/strided_slice.h" +#include +#include "src/runtime/kernel/arm/opclib/fp32/strided_slice.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_StridedSlice; + +namespace mindspore::kernel { + +int StridedSliceCPUKernel::Init() { return RET_OK; } + +int StridedSliceCPUKernel::ReSize() { return 0; } + +int StridedSliceCPUKernel::StridedSlice() { + StridedSliceParameter *param = reinterpret_cast(opParameter); + auto ret = DoStridedSlice(input_ptr_, output_ptr_, param); + if (ret != RET_OK) { + return RET_ERROR; + } + return RET_OK; +} + +int StridedSliceCPUKernel::Run() { + auto input_tensor = inputs_.at(0); + auto output_tensor = outputs_.at(0); + input_ptr_ = reinterpret_cast(input_tensor->Data()); + output_ptr_ = reinterpret_cast(output_tensor->Data()); + + auto ret = StridedSlice(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuStridedSliceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_StridedSlice); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter null pointer dereferencing."; + return nullptr; + } + auto *kernel = new (std::nothrow) StridedSliceCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != 0) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_StridedSlice, CpuStridedSliceFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.h new file mode 100644 index 00000000000..ae2aabe51a6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/strided_slice.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_STRIDED_SLICE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_STRIDED_SLICE_H_ + +#include +#include "ir/anf.h" +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class StridedSliceCPUKernel : public LiteKernel { + public: + StridedSliceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->threadNum) {} + ~StridedSliceCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int StridedSlice(); + + private: + int thread_num_; + float *input_ptr_; + float *output_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_FP32_STRIDED_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc new file mode 100644 index 00000000000..52fc717501f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/tile.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Tile; + +namespace mindspore::kernel { +int TileCPUKernel::Init() { + auto tile_parameter_ = reinterpret_cast(opParameter); + for (int i = 0; i < tile_parameter_->in_dim_; ++i) { + tile_parameter_->in_shape_[i] = inputs_[0]->shape()[i]; + tile_parameter_->out_shape_[i] = outputs_[0]->shape()[i]; + } + ComputeStrides(tile_parameter_->in_shape_, tile_parameter_->in_strides_, tile_parameter_->in_dim_); + ComputeStrides(tile_parameter_->out_shape_, tile_parameter_->out_strides_, tile_parameter_->in_dim_); + return RET_OK; +} + +void TileCPUKernel::ComputeStrides(int *shape, int *strides, int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +int TileCPUKernel::ReSize() { return RET_OK; } + +int TileCPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); + auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); + + Tile(input_addr, output_addr, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Tile); + auto *kernel = new (std::nothrow) TileCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Tile, CpuTileFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h new file mode 100644 index 00000000000..b5457802546 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/tile.h" + +namespace mindspore::kernel { +class TileCPUKernel : public LiteKernel { + public: + explicit TileCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~TileCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: + void ComputeStrides(int *shape, int *strides, int ndim); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc new file mode 100644 index 00000000000..b181e9f97b1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/topk.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_TopK; + +namespace mindspore::kernel { +int TopKCPUKernel::Init() { + lite::tensor::Tensor *input = inputs_.at(0); + topk_parameter_->last_dim_size_ = input->shape()[input->shape().size() - 1]; + topk_parameter_->loop_num_ = 1; + for (int i = 0; i < input->shape().size() - 1; ++i) { + topk_parameter_->loop_num_ *= input->shape()[i]; + } + return RET_OK; +} + +int TopKCPUKernel::ReSize() { return RET_OK; } + +int TopKCPUKernel::Run() { + auto input_data = reinterpret_cast(inputs_.at(0)->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + auto output_index = reinterpret_cast(outputs_.at(1)->Data()); + + Node *top_map = reinterpret_cast(malloc(sizeof(Node) * topk_parameter_->last_dim_size_)); + MS_EXCEPTION_IF_NULL(top_map); + topk_parameter_->topk_node_list_ = top_map; + Topk(input_data, output_data, output_index, topk_parameter_); + free(top_map); + topk_parameter_->topk_node_list_ = nullptr; + return RET_OK; +} + +kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + MS_EXCEPTION_IF_NULL(parameter); + MS_ASSERT(desc.type == PrimitiveType_Tile); + auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs); + MS_EXCEPTION_IF_NULL(kernel); + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_TopK, CpuTopKFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h new file mode 100644 index 00000000000..cfee3826577 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/topk.h" + +namespace mindspore::kernel { +class TopKCPUKernel : public LiteKernel { + public: + explicit TopKCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + topk_parameter_ = reinterpret_cast(parameter); + } + ~TopKCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + + private: + TopkParameter *topk_parameter_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc new file mode 100644 index 00000000000..7f794847ba2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/transpose.h" +#include +#include "src/runtime/kernel/arm/opclib/transpose.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Transpose; + +namespace mindspore::kernel { +namespace { + constexpr int kTransposeInputNum = 1; + constexpr int kTransposeOutputNum = 1; +} // namespace +int TransposeCPUKernel::Init() { + auto &inTensor = inputs_.front(); + auto &outTensor = outputs_.front(); + auto param = reinterpret_cast(opParameter); + auto in_shape = inTensor->shape(); + auto out_shape = outTensor->shape(); + param->strides_[param->num_axes_ - 1] = 1; + param->out_strides_[param->num_axes_ - 1] = 1; + param->data_size_ = inTensor->Size(); + for (int i = param->num_axes_ - 2; i >= 0; i--) { + param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1]; + param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; + } + return RET_OK; +} + +int TransposeCPUKernel::ReSize() { return RET_OK; } + +int TransposeCPUKernel::Run() { + MS_ASSERT(inputs_.size() == TransposeInputNum); + MS_ASSERT(outputs_.size() == TransposeOutputNum); + auto &inTensor = inputs_.front(); + auto &outTensor = outputs_.front(); + if (inTensor == nullptr || outTensor == nullptr) { + MS_LOG(ERROR) << "null pointer dreferencing."; + return RET_ERROR; + } + auto *in_data = static_cast(inTensor->Data()); + auto *out_data = static_cast(outTensor->Data()); + auto in_shape = inTensor->shape(); + auto out_shape = outTensor->shape(); + auto *input_shape = &in_shape.front(); + auto *output_shape = &out_shape.front(); + + auto ret = + DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast(opParameter)); + return ret; +} + +kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Transpose"; + return nullptr; + } + auto *kernel = new (std::nothrow) TransposeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h new file mode 100644 index 00000000000..f13ba70015b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/kernel_factory.h" + + +namespace mindspore::kernel { + +class TransposeCPUKernel : public LiteKernel { + public: + explicit TransposeCPUKernel(OpParameter *param, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(param, inputs, outputs) {} + ~TransposeCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc new file mode 100644 index 00000000000..7aaa3686780 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unique.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unique; + +namespace mindspore::kernel { +int UniqueCPUKernel::Init() { return RET_OK; } + +int UniqueCPUKernel::ReSize() { return RET_OK; } + +int UniqueCPUKernel::Run() { + auto input = reinterpret_cast(inputs_.at(0)->Data()); + auto output0 = reinterpret_cast(outputs_.at(0)->Data()); + auto output1 = reinterpret_cast(outputs_.at(1)->Data()); + + int output0_len = 0; + Unique(input, inputs_.at(0)->ElementsNum(), output0, &output0_len, output1); + + std::vector out_shape = outputs_.at(0)->shape(); + out_shape[out_shape.size() - 1] = output0_len; + outputs_.at(0)->set_shape(out_shape); + return RET_OK; +} + +kernel::LiteKernel *CpuUniqueFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter); + MS_ASSERT(desc.type == PrimitiveType_Unique); + auto *kernel = new (std::nothrow) UniqueCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Unique, CpuUniqueFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h new file mode 100644 index 00000000000..320c1bd0d65 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/unique.h" + +namespace mindspore::kernel { +class UniqueCPUKernel : public LiteKernel { + public: + UniqueCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~UniqueCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc new file mode 100644 index 00000000000..599acd84e2e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unsqueeze.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unsqueeze; + +namespace mindspore::kernel { +int UnsqueezeCPUKernel::Init() { + int ret = ReSize(); + return ret; +} + +int UnsqueezeCPUKernel::ReSize() { + data_size_ = inputs_.at(0)->ElementsNum(); + thread_sz_count_ = MSMIN(thread_count_, data_size_); + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + return RET_OK; +} + +int UnsqueezeCPUKernel::DoUnsqueeze(int task_id) { + size_t size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_); + if (size == 0) { + return RET_OK; + } + int offset = task_id * thread_sz_stride_; + int ret = Unsqueeze(in_ptr_ + offset, out_ptr_ + offset, size * sizeof(float)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int UnsqueezeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto g_kernel = reinterpret_cast(cdata); + auto ret = g_kernel->DoUnsqueeze(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +int UnsqueezeCPUKernel::Run() { + in_ptr_ = reinterpret_cast(inputs_.at(0)->Data()); + out_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); + int ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze); + auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx); + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Unsqueeze, CpuUnsqueezeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h new file mode 100644 index 00000000000..1bc38389b26 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/fp32/unsqueeze.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class UnsqueezeCPUKernel : public LiteKernel { + public: + UnsqueezeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) {} + ~UnsqueezeCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoUnsqueeze(int task_id); + + private: + int thread_count_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + float *in_ptr_; + float *out_ptr_; + const Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc new file mode 100644 index 00000000000..34854891ebc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/unstack.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Unstack; + +namespace mindspore::kernel { +int UnstackCPUKernel::Init() { + auto input = inputs_.at(0); + MS_ASSERT(input != nullptr); + size_t shape_size = input->shape().size(); + + auto para = reinterpret_cast(opParameter); + para->pre_dims_ = 1; + para->axis_dim_ = 1; + para->after_dims_ = 1; + if (para->axis_ < 0) { + para->axis_ += shape_size; + } + for (size_t i = 0; i < shape_size; i++) { + if (i < para->axis_) { + para->pre_dims_ *= input->DimensionSize(i); + } else if (i > para->axis_) { + para->after_dims_ *= input->DimensionSize(i); + } else { + para->axis_dim_ = input->DimensionSize(i); + } + } + + output_addr_array_ = reinterpret_cast(malloc(sizeof(float *) * outputs_.size())); + if (output_addr_array_ == nullptr) { + MS_LOG(ERROR) << "Failed to malloc memory"; + return lite::RET_ERROR; + } + return RET_OK; +} + +int UnstackCPUKernel::ReSize() { return RET_OK; } + +int UnstackCPUKernel::Run() { + float *input = reinterpret_cast(inputs_.at(0)->Data()); + size_t out_num = outputs_.size(); + for (size_t i = 0; i < out_num; i++) { + output_addr_array_[i] = reinterpret_cast(outputs_.at(i)->Data()); + } + Unistack(input, output_addr_array_, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuUnstackFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + MS_ASSERT(desc.type == PrimitiveType_Unstack); + auto *kernel = new (std::nothrow) UnstackCPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Unstack, CpuUnstackFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h new file mode 100644 index 00000000000..cbc9abdbdd6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/unstack.h" + +namespace mindspore::kernel { +class UnstackCPUKernel : public LiteKernel { + public: + UnstackCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~UnstackCPUKernel() { + free(output_addr_array_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + float **output_addr_array_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc new file mode 100644 index 00000000000..ebb01f54ea0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/where.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/opclib/where.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Where; + +namespace mindspore::kernel { +int WhereCPUKernel::Init() { + where_param_->op_parameter_.thread_num_ = thread_count_; + return RET_OK; +} + +int WhereCPUKernel::DoExcute(int task_id) { + Where(input_data, input_data1, input_data2, output_data, where_param_, task_id); + return RET_OK; +} + +int WhereRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto wheredata = reinterpret_cast(cdata); + auto ret = wheredata->DoExcute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "WhereRun error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} +int WhereCPUKernel::Run() { + auto input = inputs_.at(0); + auto input1 = inputs_.at(1); + auto input2 = inputs_.at(2); + int num = input->ElementsNum(); + int num1_ = input1->ElementsNum(); + int num2_ = input2->ElementsNum(); + + input_data = reinterpret_cast(input->Data()); + input_data1 = reinterpret_cast(input1->Data()); + input_data2 = reinterpret_cast(input2->Data()); + output_data = reinterpret_cast(outputs_.at(0)->Data()); + int num_max = num > num1_ ? num : (num1_ > num2_ ? num1_ : num2_); + where_param_->num_ = num; + where_param_->num1_ = num1_; + where_param_->num2_ = num2_; + where_param_->number_ = num_max; + + if (((num != 1) && (num != num_max)) || ((num1_ != 1) && (num1_ != num_max)) || + ((num2_ != 1) && (num2_ != num_max))) { + MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable"; + return RET_ERROR; + } + if (num_max <= 0) { + MS_LOG(ERROR) << "Error, inputs' length are zero !!!"; + return RET_ERROR; + } + auto ret = LiteBackendParallelLaunch(WhereRun, this, where_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "WhereDwRun error: error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +kernel::LiteKernel *CpuWhereFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + + auto *kernel = new (std::nothrow) WhereCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new WhereCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_Where, CpuWhereFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.h b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h new file mode 100644 index 00000000000..36d1559077c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/where.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class WhereCPUKernel : public LiteKernel { + public: + WhereCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->threadNum) { + where_param_ = reinterpret_cast(opParameter); + } + ~WhereCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; + int DoExcute(int task_id); + + protected: + int thread_count_; + const Context *ctx_; + WhereParameter *where_param_; + + private: + bool *input_data; + float *input_data1; + float *input_data2; + float *output_data; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc new file mode 100644 index 00000000000..8f12d39eb5a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp32/zeroslike.h" +#include +#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/opclib/zeroslike.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ZerosLike; + +namespace mindspore::kernel { +constexpr int kInputNum = 1; +constexpr int kOutputNum = 1; + +int ZerosLikeCPUKernel::Init() { return RET_OK; } + +int ZerosLikeCPUKernel::Run() { + auto input = inputs_.at(0); + auto input_data = reinterpret_cast(input->Data()); + auto output_data = reinterpret_cast(outputs_.at(0)->Data()); + ApproximateZerosLike(input_data, output_data, input->ElementsNum()); + return RET_OK; +} + +kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + if (opParameter == nullptr) { + MS_LOG(ERROR) << "input opParameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) ZerosLikeCPUKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new ZerosLikeCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, PrimitiveType_ZerosLike, CpuZerosLikeFp32KernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h new file mode 100644 index 00000000000..cd6aad26679 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ + +#include +#include "src/lite_kernel.h" + + +namespace mindspore::kernel { +class ZerosLikeCPUKernel : public LiteKernel { + public: + ZerosLikeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + ~ZerosLikeCPUKernel() = default; + + int Init() override; + int ReSize() override { return 0; } + int Run() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc new file mode 100644 index 00000000000..a03c8c049b2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/add_int8.h" +#include +#include +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Add; + +namespace mindspore::kernel { +int QuantizedAddCPUKernel::Init() { + lite::tensor::Tensor *input0 = inputs_.at(0); + lite::tensor::Tensor *input1 = inputs_.at(1); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + para_.input0_scale_ = input0->GetQuantParams().front().scale; + para_.input0_offset_ = input0->GetQuantParams().front().zeroPoint * -1; + para_.input1_scale_ = input1->GetQuantParams().front().scale; + para_.input1_offset_ = input1->GetQuantParams().front().zeroPoint * -1; + para_.output_scale_ = output->GetQuantParams().front().scale; + para_.output_offset_ = output->GetQuantParams().front().zeroPoint; + + const int left_shift = 20; // 1 << 20, 2/20 + const double twice_max_input_scale = 2 * std::max(para_.input0_scale_, para_.input1_scale_); + const double real_input0_multiplier = para_.input0_scale_ / twice_max_input_scale; + const double real_input1_multiplier = para_.input1_scale_ / twice_max_input_scale; + const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * para_.output_scale_); + + QuantizeMultiplierSmallerThanOne(real_input0_multiplier, ¶_.input0_multiplier_, ¶_.input0_shift_); + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, ¶_.input1_multiplier_, ¶_.input1_shift_); + QuantizeMultiplierSmallerThanOne(real_output_multiplier, ¶_.output_multiplier_, ¶_.output_shift_); + + para_.output_activation_min_ = std::numeric_limits::min(); + para_.output_activation_max_ = std::numeric_limits::max(); + + int left_shift0 = -para_.input0_shift_ > 0 ? -para_.input0_shift_ : 0; + para_.right_shift0_ = -para_.input0_shift_ > 0 ? 0 : para_.input0_shift_; + + int left_shift1 = -para_.input1_shift_ > 0 ? -para_.input1_shift_ : 0; + para_.right_shift1_ = -para_.input1_shift_ > 0 ? 0 : para_.input1_shift_; + + para_.left_shift_out_ = -para_.output_shift_ > 0 ? -para_.output_shift_ : 0; + para_.right_shift_out_ = -para_.output_shift_ > 0 ? 0 : para_.output_shift_; + + para_.left_shift_result0_ = (1 << left_shift) * ((1 << left_shift0)); + para_.left_shift_result1_ = (1 << left_shift) * ((1 << left_shift1)); + + MS_ASSERT(left_shift + left_shift0 == left_shift); + MS_ASSERT(left_shift + left_shift1 == left_shift); + return 0; +} + +int QuantizedAddCPUKernel::ReSize() { return 0; } + +int QuantizedAddCPUKernel::Run() { + input0_data_ = static_cast(inputs_.at(0)->Data()); + input1_data_ = static_cast(inputs_.at(1)->Data()); + output_data_ = static_cast(outputs_.at(0)->Data()); + + elements_num_ = inputs_.at(0)->ElementsNum(); + count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; + + if (inputs_.at(0)->ElementsNum() != inputs_.at(1)->ElementsNum()) { + input0_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + input1_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = outputs_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = inputs_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = inputs_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = outputs_.at(0)->DimensionSize(i); + } + TileDimensionsUint8(static_cast(inputs_.at(0)->Data()), static_cast(inputs_.at(1)->Data()), + reinterpret_cast(input0_data_), reinterpret_cast(input1_data_), + &tile_para); + auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + ctx_->allocator->Free(input0_data_); + ctx_->allocator->Free(input1_data_); + return ret; + } + + auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + return ret; +} + +int AddInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto add = reinterpret_cast(cdata); + add->DoExecute(task_id); + return lite::RET_OK; +} + +int QuantizedAddCPUKernel::DoExecute(int tId) { + int64_t real_dst_count = MSMIN(elements_num_ - tId * count_unit_, count_unit_); + int8_t *cur_input0_data = input0_data_ + tId * count_unit_; + int8_t *cur_input1_data = input1_data_ + tId * count_unit_; + int8_t *cur_output_data = output_data_ + tId * count_unit_; + + AddInt8(cur_input0_data, cur_input1_data, cur_output_data, real_dst_count, ¶_); + return lite::RET_OK; +} + +kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or ctx is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_Add); + auto *kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx); + MS_EXCEPTION_IF_NULL(kernel); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h new file mode 100644 index 00000000000..ed800fb8b1b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/add_int8.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class QuantizedAddCPUKernel : public LiteKernel { + public: + explicit QuantizedAddCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->threadNum) {} + ~QuantizedAddCPUKernel() override {} + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tId); + + private: + const lite::Context *ctx_; + AddQuantParameter para_; + int thread_count_; + int64_t elements_num_; + int64_t count_unit_; + int8_t *input0_data_ = nullptr; + int8_t *input1_data_ = nullptr; + int8_t *output_data_ = nullptr; +}; + +int AddInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ADD_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc new file mode 100644 index 00000000000..26237788eb0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/bias_add_int8.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" +#include "src/kernel_registry.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_BiasAdd; + +namespace mindspore::kernel { +int BiasAddInt8CPUKernel::Init() { + auto bias_param = reinterpret_cast(opParameter); + auto dims = inputs_[0]->shape(); + bias_param->ndim_ = dims.size(); + for (int i = 0; i < bias_param->ndim_; i++) { + bias_param->in_shape0_[i] = dims[i]; + bias_param->in_shape1_[i] = 1; + bias_param->out_shape_[i] = dims[i]; + } + bias_param->in_shape1_[3] = dims[3]; + return OPCLIB_OK; +} + +int BiasAddInt8CPUKernel::ReSize() { return OPCLIB_OK; } + +int BiasAddInt8CPUKernel::Run() { + auto in = reinterpret_cast(inputs_.at(0)->Data()); + auto bias = reinterpret_cast(inputs_.at(1)->Data()); + auto out = reinterpret_cast(outputs_.at(0)->Data()); + size_t data_size = inputs_.at(0)->ElementsNum(); + auto tile_in = static_cast(ctx_->allocator->Malloc(data_size)); + auto tile_bias = static_cast(ctx_->allocator->Malloc(data_size)); + if (tile_in == nullptr || tile_bias == nullptr) { + MS_LOG(ERROR) << "Failed to malloc momery"; + return OPCLIB_ERR; + } + BroadcastAddInt8(in, bias, tile_in, tile_bias, out, data_size, reinterpret_cast(opParameter)); + ctx_->allocator->Free(tile_in); + ctx_->allocator->Free(tile_bias); + return OPCLIB_OK; +} + +kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *parameter, const lite::Context *ctx, + const KernelKey &desc) { + if (parameter == nullptr || ctx == nullptr) { + MS_LOG(ERROR) << "parameter or context is nullptr"; + return nullptr; + } + MS_ASSERT(desc.type == PrimitiveType_BiasAdd); + auto *kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + return nullptr; + } + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h new file mode 100644 index 00000000000..5be829958d9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/unique.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +namespace mindspore::kernel { +class BiasAddInt8CPUKernel : public LiteKernel { + public: + BiasAddInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx) {} + ~BiasAddInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + const lite::Context *ctx_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc new file mode 100644 index 00000000000..ae284550e18 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/concat_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/concat_int8.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int ConcatInt8CPUKernel::Init() { + ConcatBaseCPUKernel::Init(); + quant_concat_parm_ = concat_param_->concat_quant_arg_; + quant_concat_parm_ = new (std::nothrow) ConcatQuantArg; + auto input_num = inputs_.size(); + quant_concat_parm_->input_num_ = input_num; + quant_concat_parm_->input_sizes_ = reinterpret_cast(malloc(sizeof(int) * input_num)); + if (quant_concat_parm_->input_sizes_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_sizes_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + quant_concat_parm_->input_sizes_[i] = 1; + } + quant_concat_parm_->input_shapes_ = reinterpret_cast(malloc(sizeof(int *) * input_num)); + if (quant_concat_parm_->input_shapes_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_shapes_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + auto *input_tensor = inputs_.at(i); + MS_ASSERT(input_tensor != nullptr); + auto input_size = input_tensor->shape().size(); + MS_ASSERT(input_size != NULL); + quant_concat_parm_->input_shapes_[i] = reinterpret_cast(malloc(sizeof(int) * input_size)); + if (quant_concat_parm_->input_shapes_[i] == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->input_shapes_[" << i << "]."; + return RET_ERROR; + } + + ::memcpy(quant_concat_parm_->input_shapes_[i], input_tensor->shape().data(), sizeof(int) * input_size); + for (size_t j = 0; j < input_size; j++) { + auto *input_tensor_tmp = inputs_.at(i); + auto input_shape = input_tensor_tmp->shape()[j]; + quant_concat_parm_->input_sizes_[i] *= input_shape; + } + } + + quant_concat_parm_->in_quant_args_ = reinterpret_cast(malloc(sizeof(QuantArg) * input_num)); + if (quant_concat_parm_->in_quant_args_ == nullptr) { + MS_LOG(ERROR) << "Null pointer reference: quant_concat_parm_->in_quant_args_."; + return RET_ERROR; + } + + for (size_t i = 0; i < input_num; i++) { + auto *input_tensor = inputs_.at(i); + auto quant_args = input_tensor->GetQuantParams(); + MS_ASSERT(quant_args.size() == 1); + quant_concat_parm_->in_quant_args_[i].scale_ = quant_args.front().scale; + quant_concat_parm_->in_quant_args_[i].zp_ = quant_args.front().zeroPoint; + } + + MS_ASSERT(outputs_.size() == 1); + auto output_tensor = outputs_.at(0); + MS_ASSERT(output_tensor != nullptr); + auto output_shape = output_tensor->shape(); + MS_ASSERT(output_shape != NULL); + auto output_dim = output_shape.size(); + quant_concat_parm_->output_dim_ = output_dim; + int output_size = 1; + for (size_t i = 0; i < output_dim; i++) { + output_size *= output_shape[i]; + } + quant_concat_parm_->output_size_ = output_size; + + quant_concat_parm_->output_shape_ = new int[output_size]; + ::memcpy(quant_concat_parm_->output_shape_, output_shape.data(), sizeof(int) * output_size); + + auto quant_args = output_tensor->GetQuantParams(); + MS_ASSERT(quant_args.size() == 1); + quant_concat_parm_->out_quant_args_.scale_ = quant_args.front().scale; + quant_concat_parm_->out_quant_args_.zp_ = quant_args.front().zeroPoint; + + return RET_OK; +} + +int ConcatInt8CPUKernel::ReSize() { return 0; } + +int ConcatInt8CPUKernel::Run() { + auto input_dim = quant_concat_parm_->input_num_; + int8_t **inputs_array = reinterpret_cast(malloc(sizeof(int8_t *) * input_dim)); + for (size_t i = 0; i < input_dim; i++) { + auto input_size = quant_concat_parm_->input_sizes_[i]; + inputs_array[i] = reinterpret_cast(malloc(sizeof(int8_t) * input_size)); + auto input_type = inputs_[i]->data_type(); + if (input_type == kNumberTypeUInt8) { + uint8_t *input_tmp = reinterpret_cast(inputs_[i]->Data()); + for (size_t j = 0; j < input_size; j++) { + inputs_array[i][j] = (int8_t)(input_tmp[j] - 128); + } + for (size_t j = 0; j < input_dim; j++) { + quant_concat_parm_->in_quant_args_[j].zp_ -= 128; + } + quant_concat_parm_->out_quant_args_.zp_ -= 128; + } else { + ::memcpy(inputs_array[i], inputs_.at(i)->Data(), sizeof(int8_t) * input_size); + } + } + int8_t *output_addr = reinterpret_cast(outputs_.at(0)->Data()); + Concat(inputs_array, output_addr, quant_concat_parm_, axis_); + auto output_type = outputs_[0]->data_type(); + if (output_type == kNumberTypeUInt8) { + auto output_size = quant_concat_parm_->output_size_; + for (size_t i = 0; i < output_size; i++) { + output_addr[i] = (uint8_t)(output_addr[i] + 128); + } + } + + for (int i = 0; i < input_dim; i++) { + free(*(inputs_array + i)); + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h new file mode 100644 index 00000000000..192f50b46ce --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/concat_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ConcatInt8CPUKernel : public ConcatBaseCPUKernel { + public: + ConcatInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConcatInt8CPUKernel() override { delete quant_concat_parm_; } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + ConcatQuantArg *quant_concat_parm_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc new file mode 100644 index 00000000000..9d4655d860b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/conv_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParameter *conv_param) { + auto input_channel = conv_param->input_channel_; + auto output_channel = conv_param->output_channel_; + auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int iC8 = UP_DIV(input_channel, C8NUM); + + size_t tmp_size = output_channel * iC8 * C8NUM * kernel_plane * sizeof(int16_t); + auto tmp_addr = reinterpret_cast(malloc(tmp_size)); + memset(tmp_addr, 0, tmp_size); + PackWeightToC8Int8(origin_weight, tmp_addr, conv_param); + Conv3x3Int8FilterTransform(tmp_addr, dst_weight, iC8, output_channel, kernel_plane); + + free(tmp_addr); +} + +Convolution3x3Int8CPUKernel::~Convolution3x3Int8CPUKernel() { + if (transformed_filter_addr_ != nullptr) { + free(transformed_filter_addr_); + } + if (input_data_ != nullptr) { + free(input_data_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + FreeQuantParam(); +} + +int Convolution3x3Int8CPUKernel::InitWeightBias() { + auto input_channel = conv_param_->input_channel_; + auto output_channel = conv_param_->output_channel_; + int iC8 = UP_DIV(input_channel, C8NUM); + int oC4 = UP_DIV(output_channel, C4NUM); + // init weight + size_t transformed_size = iC8 * C8NUM * oC4 * C4NUM * 16 * sizeof(int16_t); + transformed_filter_addr_ = reinterpret_cast(malloc(transformed_size)); + if (transformed_filter_addr_ == nullptr) { + MS_LOG(ERROR) << "malloc transformed_filter_addr_ failed."; + return RET_ERROR; + } + memset(transformed_filter_addr_, 0, transformed_size); + auto weight_data = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + ProcessFilterUint8(weight_data, transformed_filter_addr_, conv_param_); + + // init bias + size_t new_bias_size = oC4 * C4NUM * sizeof(int32_t); + bias_data_ = reinterpret_cast(malloc(new_bias_size)); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, new_bias_size); + if (inputs_.size() == kInputSize2) { + auto ori_bias_addr = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::InitTmpBuffer() { + int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM); + int oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int in_batch = conv_param_->input_batch_; + int input_w = conv_param_->input_w_; + int input_h = conv_param_->input_h_; + int output_batch = conv_param_->output_batch_; + int output_w = conv_param_->output_w_; + int output_h = conv_param_->output_h_; + + size_t tile_buffer_size = thread_count_ * TILE_NUM * 16 * ic8 * C8NUM * sizeof(int16_t); + tile_buffer_ = reinterpret_cast(malloc(tile_buffer_size)); + if (tile_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tile_buffer_ failed."; + return RET_ERROR; + } + memset(tile_buffer_, 0, tile_buffer_size); + + size_t block_unit_buffer_size = thread_count_ * 4 * 4 * C8NUM * sizeof(int16_t); + block_unit_buffer_ = reinterpret_cast(malloc(block_unit_buffer_size)); + if (block_unit_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc block_unit_buffer_ failed."; + return RET_ERROR; + } + memset(block_unit_buffer_, 0, block_unit_buffer_size); + + size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * 16 * oc4 * C4NUM * sizeof(int32_t); + tmp_dst_buffer_ = reinterpret_cast(malloc(tmp_dst_buffer_size)); + if (tmp_dst_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; + return RET_ERROR; + } + memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); + + size_t tmp_out_size = oc4 * C4NUM * output_batch * output_w * output_h * sizeof(uint8_t); + tmp_out_ = reinterpret_cast(malloc(tmp_out_size)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + memset(tmp_out_, 0, tmp_out_size); + + size_t c8_input_size = in_batch * input_h * input_w * ic8 * C8NUM * sizeof(int16_t); + input_data_ = reinterpret_cast(malloc(c8_input_size)); + if (input_data_ == nullptr) { + MS_LOG(ERROR) << "malloc input_data_ failed."; + return RET_ERROR; + } + memset(input_data_, 0, c8_input_size); + return RET_OK; +} + +void Convolution3x3Int8CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); +} + +int Convolution3x3Int8CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + SetQuantParam(); + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::ReSize() { + if (input_data_ != nullptr) { + free(input_data_); + } + if (tile_buffer_ != nullptr) { + free(tile_buffer_); + } + if (block_unit_buffer_ != nullptr) { + free(block_unit_buffer_); + } + if (tmp_dst_buffer_ != nullptr) { + free(tmp_dst_buffer_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + Conv3x3Int8(input_data_, transformed_filter_addr_, reinterpret_cast(bias_data_), output_addr, tile_buffer_, + block_unit_buffer_, tmp_dst_buffer_, tmp_out_, task_id, conv_param_); + return RET_OK; +} + +int Convolution3x3Int8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution3x3 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int Convolution3x3Int8CPUKernel::Run() { + auto input_addr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + PackInputToC8Int8(input_addr, input_data_, conv_param_); + + int error_code = LiteBackendParallelLaunch(Convolution3x3Int8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv3x3 int8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h new file mode 100644 index 00000000000..555ed0522bf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ + +#include +#include "src/lite_kernel.h" + +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" + +namespace mindspore::kernel { +class Convolution3x3Int8CPUKernel : public ConvolutionBaseCPUKernel { + public: + Convolution3x3Int8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~Convolution3x3Int8CPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + int16_t *transformed_filter_addr_; + int16_t *input_data_; + int16_t *tile_buffer_; + int16_t *block_unit_buffer_; + int32_t *tmp_dst_buffer_; + int8_t *tmp_out_; +}; +void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParameter *conv_param); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc new file mode 100644 index 00000000000..cbd76d4dd98 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_depthwise_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() { + // init weight, int8 -> int16 + // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_[kWeightIndex]->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); + memset(packed_weight_, 0, pack_weight_size * sizeof(int16_t)); + PackDepthwiseInt8Weight(origin_weight, packed_weight_, conv_param_); + + // init bias, add output zp + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); + memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(int32_t)); + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Init() { + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init sliding window param + sliding = new SlidingWindowParam; + InitSlidingParam(sliding, conv_param_, C4NUM); + + // init quant param + ConvolutionBaseCPUKernel::SetQuantParam(); + + // init weight and bias + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!"; + return ret; + } + + ret = ReSize(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Depthwise int8 ReSize error!"; + return ret; + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::ReSize() { + // malloc packed input buffer + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * + UP_DIV(conv_param_->input_channel_, 4); + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(int16_t))); + memset(packed_input_, 0, pack_input_size * sizeof(int16_t)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * + (conv_param_->output_channel_, C4NUM); + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_output_, 0, pack_output_size * sizeof(int8_t)); + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Execute(int task_id) { + ConvDwInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, + sliding, task_id); + return RET_OK; +} + +int ConvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv_dw = reinterpret_cast(cdata); + auto ret = conv_dw->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionDepthwiseInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionDepthwiseInt8CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + // pack input, assume input format: NHWC -> NHWC4 + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_); + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(int8_t)); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(ConvDwInt8Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvDwInt8Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h new file mode 100644 index 00000000000..2e9ad6fd392 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class ConvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionDepthwiseInt8CPUKernel() override { + delete sliding; + free(packed_weight_); + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitWeightBias(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding; + int16_t *packed_weight_; + int16_t *packed_input_; + int8_t *packed_output_; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_DEPTHWISE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc new file mode 100644 index 00000000000..cbd652225c7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/convolution_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/conv_int8.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { +void ConvolutionInt8CPUKernel::CheckSupportOptimize() { + tile_num_ = 24; +#ifdef ENABLE_ARM32 + tile_num_ = 2; + support_optimize_ = false; +#endif + +#ifdef __aarch64__ + void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; + if (optimize_op_handler != nullptr) { + dlerror(); + *(reinterpret_cast(&gemm_func_)) = dlsym(optimize_op_handler, "IndirectGemmInt8_optimize_handler"); + auto dlopen_error = dlerror(); + if (dlopen_error != nullptr) { + MS_LOG(ERROR) << "load gemm func failed! " << dlopen_error << "."; + tile_num_ = 4; + support_optimize_ = false; + gemm_func_ = nullptr; + } else { + // do nothing + } + } else { + tile_num_ = 4; + support_optimize_ = false; + } +#endif +} + +int ConvolutionInt8CPUKernel::InitWeightBias() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * plane_c4 * C4NUM; + int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; + int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size)); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, 0, pack_weight_size); + int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + for (int i = 0; i < out_channel; i++) weight_sum[i] = 0; + PackWeightInt8(origin_weight, conv_param_, packed_weight_, weight_sum); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc4 * C4NUM * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + auto *bias_data = reinterpret_cast(bias_data_); + int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + free(weight_sum); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitTmpBuffer() { + int tile_n = 4; + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, tile_n); + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int unit_size = plane_c4 * C4NUM * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); + + input_sum_ = reinterpret_cast(malloc(tile_n * thread_count_ * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t)); + + size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t); + tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); + if (tmp_dst_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_ failed."; + return RET_ERROR; + } + memset(tmp_dst_, 0, tmp_dst_size); + + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_n * conv_param_->output_channel_)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + + size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitWeightBiasOpt() { + int kernel_h = conv_param_->kernel_h_; + int kernel_w = conv_param_->kernel_w_; + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_channel = conv_param_->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; + int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_; + int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_; + + // init weight + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + packed_weight_ = reinterpret_cast(malloc(pack_weight_size)); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_weight_ failed."; + return RET_ERROR; + } + memset(packed_weight_, filter_zp, pack_weight_size); + int32_t *weight_sum = reinterpret_cast(malloc(sizeof(int32_t) * out_channel)); + for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane; + PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum); + + // init bias + bias_data_ = reinterpret_cast(malloc(oc4 * C4NUM * sizeof(int32_t))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "malloc bias_data_ failed."; + return RET_ERROR; + } + memset(bias_data_, 0, oc4 * C4NUM * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, out_channel * sizeof(int32_t)); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + auto *bias_data = reinterpret_cast(bias_data_); + int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM; + for (int i = 0; i < out_channel; i++) { + bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp; + } + free(weight_sum); + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitTmpBufferOpt() { + // todo + int tile_n = 24; + int output_count = conv_param_->output_h_ * conv_param_->output_w_; + int output_tile_count = UP_DIV(output_count, tile_n); + int in_channel = conv_param_->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + packed_input_ = reinterpret_cast(malloc(conv_param_->input_batch_ * packed_input_size)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "malloc packed_input_ failed."; + return RET_ERROR; + } + memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size); + + input_sum_ = reinterpret_cast(malloc(tile_n * thread_count_ * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + memset(input_sum_, 0, tile_n * thread_count_ * sizeof(int32_t)); + + size_t tmp_dst_size = thread_count_ * tile_n * conv_param_->output_channel_ * sizeof(int32_t); + tmp_dst_ = reinterpret_cast(malloc(tmp_dst_size)); + if (tmp_dst_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_dst_ failed."; + return RET_ERROR; + } + memset(tmp_dst_, 0, tmp_dst_size); + + tmp_out_ = reinterpret_cast(malloc(thread_count_ * tile_n * conv_param_->output_channel_)); + if (tmp_out_ == nullptr) { + MS_LOG(ERROR) << "malloc tmp_out_ failed."; + return RET_ERROR; + } + + size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_; + nhwc4_input_ = malloc(nhwc4_input_size); + if (nhwc4_input_ == nullptr) { + MS_LOG(ERROR) << "malloc nhwc4 input failed."; + return RET_ERROR; + } + memset(nhwc4_input_, 0, nhwc4_input_size); + return RET_OK; +} + +void ConvolutionInt8CPUKernel::ConfigInputOutput() { + auto output_tensor = outputs_.at(kOutputIndex); + output_tensor->SetFormat(schema::Format_NHWC); + auto input_tensor = inputs_.at(kInputIndex); + auto ret = CheckLayout(input_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Check layout failed."; + return; + } +} + +int ConvolutionInt8CPUKernel::Init() { + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + // config input output + ConfigInputOutput(); + CheckSupportOptimize(); + SetQuantParam(); + // init for opt + if (support_optimize_) { + ret = InitOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Initialization for optimized int8 conv failed."; + return RET_ERROR; + } + return RET_OK; + } + + // init for situation that not support sdot + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::InitOpt() { + auto ret = InitWeightBiasOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init weight bias failed."; + return RET_ERROR; + } + // init tmp input, output + ret = InitTmpBufferOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::ReSize() { + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (input_sum_ != nullptr) { + free(input_sum_); + } + if (tmp_dst_ != nullptr) { + free(tmp_dst_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvolutionBase init failed."; + return RET_ERROR; + } + if (support_optimize_) { + ret = InitTmpBufferOpt(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer for opt failed."; + return RET_ERROR; + } + return RET_OK; + } + // init tmp input, output + ret = InitTmpBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init tmp buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::RunImpl(int task_id) { + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (support_optimize_) { + ConvInt8Opt(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id, + conv_param_, gemm_func_); + } else { + ConvInt8(reinterpret_cast(nhwc4_input_), packed_input_, packed_weight_, + reinterpret_cast(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id, + conv_param_); + } + return RET_OK; +} + +int ConvolutionInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Convolution Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionInt8CPUKernel::Run() { + auto input_tensor = inputs_.at(kInputIndex); + auto ori_input_data = input_tensor->Data(); + int in_batch = conv_param_->input_batch_; + int in_h = conv_param_->input_h_; + int in_w = conv_param_->input_w_; + int in_channel = conv_param_->input_channel_; + convert_func_(ori_input_data, nhwc4_input_, in_batch, in_h * in_w, in_channel); + + int error_code = LiteBackendParallelLaunch(ConvolutionInt8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv int8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h new file mode 100644 index 00000000000..cfae7492ec2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/optimized_kernel.h" +#include "src/runtime/kernel/arm/opclib/int8/conv_int8.h" + +namespace mindspore::kernel { +class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + ConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ConvolutionInt8CPUKernel() override { + if (packed_weight_ != nullptr) { + free(packed_weight_); + } + if (packed_input_ != nullptr) { + free(packed_input_); + } + if (input_sum_ != nullptr) { + free(input_sum_); + } + if (tmp_dst_ != nullptr) { + free(tmp_dst_); + } + if (tmp_out_ != nullptr) { + free(tmp_out_); + } + FreeQuantParam(); + }; + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + void CheckSupportOptimize(); + int InitOpt(); + int InitWeightBiasOpt(); + int InitTmpBufferOpt(); + int InitWeightBias(); + int InitTmpBuffer(); + void ConfigInputOutput(); + + private: + bool support_optimize_ = true; + int8_t *packed_weight_ = nullptr; + int8_t *packed_input_ = nullptr; + int32_t *input_sum_ = nullptr; + int32_t *tmp_dst_ = nullptr; + int8_t *tmp_out_ = nullptr; + GEMM_FUNC gemm_func_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc new file mode 100644 index 00000000000..4cc9a774584 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() { + // init weight: int8 -> int16 + // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_[kWeightIndex]->Data()); + int OC4 = UP_DIV(conv_param_->output_channel_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv_param_->kernel_h_ * conv_param_->kernel_w_; + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); + memset(packed_weight_, 0, pack_weight_size * sizeof(int16_t)); + PackDepthwiseInt8Weight(origin_weight, packed_weight_, conv_param_); + + // init bias, add output zp + bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); + memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t)); + if (inputs_.size() == kInputSize2) { + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy(bias_data_, ori_bias, conv_param_->output_channel_ * sizeof(int32_t)); + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::InitSlideParam() { + conv_param_->input_batch_ = outputs_.front()->shape().at(kNHWC_N); + conv_param_->input_h_ = outputs_.front()->shape().at(kNHWC_H); + conv_param_->input_w_ = outputs_.front()->shape().at(kNHWC_W); + conv_param_->input_channel_ = C4NUM; + conv_param_->output_batch_ = inputs_.front()->shape().at(kNHWC_N); + conv_param_->output_h_ = inputs_.front()->shape().at(kNHWC_H); + conv_param_->output_w_ = inputs_.front()->shape().at(kNHWC_W); + conv_param_->output_channel_ = inputs_.front()->shape().at(kNHWC_C); + + // init sliding window param + sliding = new SlidingWindowParam; + InitSlidingParam(sliding, conv_param_, C4NUM); + + sliding->in_h_step_ = conv_param_->input_w_ * C4NUM; + sliding->in_sh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->stride_h_; // stride H + sliding->in_sw_step_ = C4NUM * conv_param_->stride_h_; // stride W + sliding->in_kh_step_ = conv_param_->input_w_ * C4NUM * conv_param_->dilation_h_; // kernel H + sliding->in_kw_step_ = C4NUM * conv_param_->dilation_w_; // kernel W + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Init() { + InitSlideParam(); + + // conv base init + ConvolutionBaseCPUKernel::Init(); + + // init quant param + ConvolutionBaseCPUKernel::SetQuantParam(); + + // init weight and bias + auto ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Deconv Depthwise int8 InitWeightBias error!"; + return ret; + } + + ret = ReSize(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Deconv Depthwise int8 ReSize error!"; + return ret; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::ReSize() { + // malloc packed input buffer + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * + UP_DIV(conv_param_->input_channel_, 4); + packed_input_ = reinterpret_cast(malloc(pack_input_size * sizeof(int16_t))); + memset(packed_input_, 0, pack_input_size * sizeof(int16_t)); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + + if (conv_param_->input_channel_ % C4NUM != 0) { + need_align_ = true; + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * + (conv_param_->output_channel_, C4NUM); + packed_output_ = reinterpret_cast(malloc(pack_output_size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + memset(packed_output_, 0, pack_output_size * sizeof(int8_t)); + } + + // malloc tmp buffer for int32 output + output_buffer = + reinterpret_cast(malloc(conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * sizeof(int32_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Execute(int task_id) { + DeconvDwInt8(packed_output_, output_buffer, packed_input_, packed_weight_, reinterpret_cast(bias_data_), + conv_param_, sliding, task_id); + return RET_OK; +} + +int DeconvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv_dw = reinterpret_cast(cdata); + auto ret = deconv_dw->Execute(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvolutionDepthwiseInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeconvolutionDepthwiseInt8CPUKernel::Run() { + if (conv_param_->input_channel_ != conv_param_->output_channel_) { + MS_LOG(ERROR) << "Only support input channel equals output channel."; + return RET_ERROR; + } + + // pack input, assume input format: NHWC -> NHWC4 + auto input_tensor = inputs_.at(kInputIndex); + auto input_addr = reinterpret_cast(input_tensor->Data()); + PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_); + + auto output_addr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + memset(output_addr, 0, outputs_.at(kOutputIndex)->ElementsNum() * sizeof(int8_t)); + if (!need_align_) { + packed_output_ = output_addr; + } + + auto ret = LiteBackendParallelLaunch(DeconvDwInt8Run, this, conv_param_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DeconvDwInt8Run error: error_code[" << ret << "]"; + return RET_ERROR; + } + + if (need_align_) { + PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h new file mode 100644 index 00000000000..a394839bcae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" + +namespace mindspore::kernel { +class DeconvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + DeconvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeconvolutionDepthwiseInt8CPUKernel() override { + delete sliding; + free(packed_weight_); + free(packed_input_); + if (need_align_) { + free(packed_output_); + } + }; + + int Init() override; + int ReSize() override; + int Run() override; + + int InitSlideParam(); + int InitWeightBias(); + int Execute(int task_id); + + private: + SlidingWindowParam *sliding; + int16_t *packed_weight_; + int16_t *packed_input_; + int8_t *packed_output_; + int32_t *output_buffer; + bool need_align_ = false; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_DEPTHWISE_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc new file mode 100644 index 00000000000..a3f10aad238 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/deconvolution_int8.h" +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { +DeConvInt8CPUKernel::~DeConvInt8CPUKernel() { + if (weight_ptr_ != nullptr) { + free(weight_ptr_); + weight_ptr_ = nullptr; + } + if (tmp_buffer_ != nullptr) { + free(tmp_buffer_); + tmp_buffer_ = nullptr; + } + if (input_ptr_ != nullptr) { + free(input_ptr_); + input_ptr_ = nullptr; + } + if (tmp_output_ != nullptr) { + free(tmp_output_); + tmp_output_ = nullptr; + } + ConvolutionBaseCPUKernel::FreeQuantParam(); +} + +int DeConvInt8CPUKernel::ReSize() { return RET_OK; } + +int DeConvInt8CPUKernel::InitParam() { + fc_param_ = new MatMulParameter(); + fc_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; + fc_param_->deep_ = conv_param_->input_channel_; + fc_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM); + fc_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + size_t oc8 = UP_DIV(conv_param_->output_channel_, C8NUM); + thread_count_ = MSMIN(opParameter->thread_num_, oc8); + thread_stride_ = UP_DIV(oc8, thread_count_) * C8NUM; + return RET_OK; +} + +int DeConvInt8CPUKernel::InitBiasWeight() { + if (inputs_.size() == 3) { + size_t size = UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(int32_t); + bias_data_ = malloc(size); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, size); + memcpy(bias_data_, inputs_[0]->Data(), conv_param_->output_channel_ * sizeof(int32_t)); + } else { + bias_data_ = nullptr; + } + + /* weight: ichwoc(nhwc) -> oc8 * h * w * inc * 8 */ + size_t size = conv_param_->kernel_w_ * conv_param_->kernel_h_ * UP_ROUND(conv_param_->output_channel_, C8NUM) * + conv_param_->input_channel_ * sizeof(int8_t); + weight_ptr_ = reinterpret_cast(malloc(size)); + if (weight_ptr_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc weight_ptr_ error!"; + return RET_ERROR; + } + memset(weight_ptr_, 0, size); + PackNHWCToC8HWN8Int8(inputs_[1]->Data(), weight_ptr_, conv_param_->input_channel_, + conv_param_->kernel_h_ * conv_param_->kernel_w_, conv_param_->output_channel_); + return RET_OK; +} + +int DeConvInt8CPUKernel::InitData() { + int size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * conv_param_->input_channel_; + input_ptr_ = reinterpret_cast(malloc(size * sizeof(int8_t))); + if (input_ptr_ == nullptr) { + return RET_MEMORY_FAILED; + } + memset(input_ptr_, 0, size * sizeof(int8_t)); + + size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * + UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_; + tmp_buffer_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (tmp_buffer_ == nullptr) { + return RET_MEMORY_FAILED; + } + + size = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->output_h_ * conv_param_->output_w_; + tmp_output_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (tmp_output_ == nullptr) { + return RET_MEMORY_FAILED; + } + return RET_OK; +} + +int DeConvInt8CPUKernel::Init() { + ConvolutionBaseCPUKernel::Init(); + int error_code = ConvolutionBaseCPUKernel::SetQuantParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 SetQuantParam error!"; + return error_code; + } + + error_code = InitParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitParam error!"; + return error_code; + } + + error_code = InitBiasWeight(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitBiasWeight error!"; + return error_code; + } + + error_code = InitData(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 InitData error!"; + return error_code; + } + return RET_OK; +} + +int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoDeconv(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvInt8PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto deconv = reinterpret_cast(cdata); + auto error_code = deconv->DoPostFunc(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "DeConvInt8PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int DeConvInt8CPUKernel::DoDeconv(int task_id) { + int cur_oc = MSMIN(thread_stride_, UP_ROUND(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + int input_plane = conv_param_->input_h_ * conv_param_->input_w_; + int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; + + DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * kernel_plane * conv_param_->input_channel_, + tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, fc_param_->row_8_, + cur_oc * kernel_plane, fc_param_->deep_, conv_param_); + + return RET_OK; +} + +int DeConvInt8CPUKernel::DoPostFunc(int task_id) { + int input_plane = conv_param_->input_h_ * conv_param_->input_w_; + int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; + int output_plane = conv_param_->output_h_ * conv_param_->output_w_; + + int cur_oc = MSMIN(thread_stride_, conv_param_->output_channel_ - task_id * thread_stride_); + if (cur_oc <= 0) { + return RET_OK; + } + + DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, + reinterpret_cast(bias_data_) + task_id * thread_stride_, + tmp_output_ + task_id * thread_stride_ * output_plane, output_ptr_ + task_id * thread_stride_, cur_oc, + conv_param_); + return RET_OK; +} + +int DeConvInt8CPUKernel::Run() { + int8_t *src_in = reinterpret_cast(inputs_[0]->Data()); + int8_t *src_out = reinterpret_cast(outputs_[0]->Data()); + + for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { + RowMajor2Col8MajorInt8(src_in + batch_index * fc_param_->row_ * conv_param_->input_channel_, input_ptr_, + fc_param_->row_, fc_param_->deep_); + output_ptr_ = src_out + batch_index * fc_param_->col_; + + int error_code = LiteBackendParallelLaunch(DeConvInt8Run, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + error_code = LiteBackendParallelLaunch(DeConvInt8PostFuncRun, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 post run error! error_code[" << error_code << "]"; + return RET_ERROR; + } + } + + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h new file mode 100644 index 00000000000..ba73bafc528 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/arm/opclib/int8/deconv.h" +#include "src/runtime/kernel/arm/opclib/int8/matmul.h" +#include "src/runtime/kernel/arm/base/layout_transform.h" +#include "src/runtime/kernel/arm/base/convolution_base.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +namespace mindspore::kernel { +class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel { + public: + DeConvInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~DeConvInt8CPUKernel() override; + + int ReSize() override; + int Init() override; + int Run() override; + + public: + int DoDeconv(int task_id); + int DoPostFunc(int task_id); + + private: + int InitData(); + int InitParam(); + int InitBiasWeight(); + + private: + MatMulParameter *fc_param_; + int8_t *weight_ptr_; + int8_t *input_ptr_; /* record c8 input*/ + int32_t *tmp_buffer_; /* record matmul result */ + int32_t *tmp_output_; /* record post c8 result */ + int8_t *output_ptr_; + size_t thread_count_; + size_t thread_stride_; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc new file mode 100644 index 00000000000..8f05cf2b27a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/fullconnection_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/matmul.h" +#include "src/runtime/kernel/arm/opclib/common_func.h" +#include "include/errorcode.h" + +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int FullconnectionInt8CPUKernel::Init() { + fc_param_->row_ = (inputs_[0]->shape())[0]; + fc_param_->col_ = (inputs_[1]->shape())[1]; + fc_param_->deep_ = (inputs_[1]->shape())[0]; + fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8); + fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8); + + a_c8_ptr_ = + reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t))); + memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)); + b_r8_ptr_ = + reinterpret_cast(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t))); + memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)); + c_r8x8_ptr_ = reinterpret_cast(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int))); + memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)); + if (!a_c8_ptr_ || !b_r8_ptr_ || !c_r8x8_ptr_) { + return RET_MEMORY_FAILED; + } + + auto input_tensor = inputs_[0]; + auto params = input_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.input.zp_ = params.front().zeroPoint; + quant_params_.input.scale_ = params.front().scale; + auto weight_tensor = inputs_[1]; + params = weight_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.weight.zp_ = params.front().zeroPoint; + quant_params_.weight.scale_ = params.front().scale; + auto output_tensor = outputs_[0]; + params = output_tensor->GetQuantParams(); + MS_ASSERT(params.size() == 1); + quant_params_.output.zp_ = params.front().zeroPoint; + quant_params_.output.scale_ = params.front().scale; + + double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_; + QuantizeMultiplier(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.output_shift); + CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_, + quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min); + + return RET_OK; +} + +int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; } + +int FullconnectionInt8CPUKernel::Run() { + auto a_ptr = reinterpret_cast(inputs_.at(0)->Data()); + auto b_ptr = reinterpret_cast(inputs_.at(1)->Data()); + auto bias_ptr = reinterpret_cast(inputs_.at(2)->Data()); + auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); + auto &p = quant_params_; + + // rows*depth -> rows*depth, col_8 major + RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); + // cols*depth -> cols*depth, col_8 major == depth*cols, row_8 major + RowMajor2Col8MajorInt8(b_ptr, b_r8_ptr_, fc_param_->col_, fc_param_->deep_); + MatMulInt8(a_c8_ptr_, b_r8_ptr_, c_r8x8_ptr_, fc_param_->row_8_, fc_param_->col_8_, fc_param_->deep_, p.input.zp_, + p.weight.zp_); + PostFuncInt8(c_r8x8_ptr_, bias_ptr, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->col_8_, + fc_param_->row_8_, p.quant_multiplier, p.output_shift, p.output.zp_, p.out_act_min, p.out_act_max); + + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h new file mode 100644 index 00000000000..e953d4e6383 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ + +#include +#include "include/context.h" +#include "src/runtime/kernel/arm/opclib/quantization/quantize.h" +#include "src/runtime/kernel/arm/base/fullconnection_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel { + public: + FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~FullconnectionInt8CPUKernel() override { + free(a_c8_ptr_); + free(b_r8_ptr_); + free(c_r8x8_ptr_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: + FcQuantArg quant_params_; + int8_t *a_c8_ptr_; + int8_t *b_r8_ptr_; + int *c_r8x8_ptr_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_FULLCONNECTION_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc new file mode 100644 index 00000000000..00e2fe5d518 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/mul_int8.h" +#include +#include +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" +#include "src/runtime/kernel/arm/opclib/int8/mul_int8.h" +#include "src/runtime/runtime_api.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Mul; + +namespace mindspore::kernel { +int MulInt8CPUKernel::Init() { + lite::tensor::Tensor *input0 = inputs_.at(0); + lite::tensor::Tensor *input1 = inputs_.at(1); + lite::tensor::Tensor *output = outputs_.at(0); + MS_ASSERT(input0); + MS_ASSERT(input1); + MS_ASSERT(output); + + para_.mul_quant_arg_.in_quant_args_[0].scale_ = input0->GetQuantParams().front().scale; + para_.mul_quant_arg_.in_quant_args_[0].zp_ = input0->GetQuantParams().front().zeroPoint * -1; + para_.mul_quant_arg_.in_quant_args_[1].scale_ = input1->GetQuantParams().front().scale; + para_.mul_quant_arg_.in_quant_args_[1].zp_ = input1->GetQuantParams().front().zeroPoint * -1; + para_.mul_quant_arg_.out_quant_arg_.scale_ = output->GetQuantParams().front().scale; + para_.mul_quant_arg_.out_quant_arg_.zp_ = output->GetQuantParams().front().zeroPoint; + para_.mul_quant_arg_.output_activation_max_ = std::numeric_limits::max(); + para_.mul_quant_arg_.output_activation_min_ = std::numeric_limits::min(); + + const double real_multiplier = + (para_.mul_quant_arg_.in_quant_args_[0].scale_ * para_.mul_quant_arg_.in_quant_args_[1].scale_) / + para_.mul_quant_arg_.out_quant_arg_.scale_; + + int right_shift = 0; + QuantizeMultiplierSmallerThanOne(real_multiplier, ¶_.mul_quant_arg_.output_multiplier_, &right_shift); + + para_.mul_quant_arg_.shift_left_ = right_shift < 0 ? -right_shift : 0; + para_.mul_quant_arg_.shift_right_ = right_shift > 0 ? right_shift : 0; + + return RET_OK; +} + +int MulInt8CPUKernel::ReSize() { return RET_OK; } + +int MulInt8CPUKernel::Run() { + input0_data_ = static_cast(inputs_.at(0)->Data()); + input1_data_ = static_cast(inputs_.at(1)->Data()); + output_data_ = static_cast(outputs_.at(0)->Data()); + + elements_num_ = inputs_.at(0)->ElementsNum(); + count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_; + + if (inputs_.at(0)->ElementsNum() != inputs_.at(1)->ElementsNum()) { + input0_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + input1_data_ = static_cast(ctx_->allocator->Malloc(outputs_.at(0)->Size())); + + ArithmeticParameter tile_para = {0}; + tile_para.ndim_ = outputs_.at(0)->shape().size(); + for (size_t i = 0; i < tile_para.ndim_; i++) { + tile_para.in_shape0_[i] = inputs_.at(0)->DimensionSize(i); + tile_para.in_shape1_[i] = inputs_.at(1)->DimensionSize(i); + tile_para.out_shape_[i] = outputs_.at(0)->DimensionSize(i); + } + TileDimensionsInt8(static_cast(inputs_.at(0)->Data()), static_cast(inputs_.at(1)->Data()), + input0_data_, input1_data_, &tile_para); + auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_); + ctx_->allocator->Free(input0_data_); + ctx_->allocator->Free(input1_data_); + return ret; + } + + auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_); + return ret; +} + +int MulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto mul = reinterpret_cast(cdata); + mul->DoExecute(task_id); + return lite::RET_OK; +} + +int MulInt8CPUKernel::DoExecute(int tId) { + int64_t real_dst_count = MSMIN(elements_num_ - tId * count_unit_, count_unit_); + int8_t *cur_input0_data = input0_data_ + tId * count_unit_; + int8_t *cur_input1_data = input1_data_ + tId * count_unit_; + int8_t *cur_output_data = output_data_ + tId * count_unit_; + + Mul(cur_input0_data, cur_input1_data, cur_output_data, real_dst_count, para_.mul_quant_arg_); + return lite::RET_OK; +} + +kernel::LiteKernel *CpuMulInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, const KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Mul); + auto *kernel = new (std::nothrow) MulInt8CPUKernel(opParameter, inputs, outputs, ctx); + + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h new file mode 100644 index 00000000000..424a70c79ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/mul_parameter.h" +#include "src/runtime/runtime_api.h" + +namespace mindspore::kernel { +class MulInt8CPUKernel : public LiteKernel { + public: + explicit MulInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->threadNum) {} + ~MulInt8CPUKernel() override {}; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExecute(int tId); + + private: + const lite::Context *ctx_; + MulParameter para_; + int thread_count_; + int64_t elements_num_; + int64_t count_unit_; + int8_t *input0_data_ = nullptr; + int8_t *input1_data_ = nullptr; + int8_t *output_data_ = nullptr; +}; + +int MulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc new file mode 100644 index 00000000000..f52e8aebdbd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/opclib/fp32/cast.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int PoolingInt8CPUKernel::Init() { + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + ret = SetQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling quant param failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::ReSize() { + FreeQuantParam(); + auto ret = PoolingBaseCPUKernel::Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PoolingBase Init failed."; + return RET_ERROR; + } + SetQuantParam(); + ret = SetQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Set pooling quant param failed."; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::RunImpl(int task_id) { + auto input_data = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + auto output_data = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (pooling_param_->max_pooling_) { + MaxPoolingInt8(input_data, output_data, pooling_param_, task_id); + } else { + AvgPoolingInt8(input_data, output_data, pooling_param_, task_id); + } + return RET_OK; +} + +int PoolingInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { + auto pooling = reinterpret_cast(cdata); + auto error_code = pooling->RunImpl(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "PoolingInt8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int PoolingInt8CPUKernel::Run() { + int error_code = LiteBackendParallelLaunch(PoolingInt8Impl, this, thread_count_); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "poolingInt8 error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h new file mode 100644 index 00000000000..367c10f59c6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "ir/anf.h" +#include "include/context.h" +#include "src/runtime/kernel/arm/base/pooling_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class PoolingInt8CPUKernel : public PoolingBaseCPUKernel { + public: + PoolingInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~PoolingInt8CPUKernel() { FreeQuantParam(); } + + int Init() override; + int ReSize() override; + int Run() override; + int RunImpl(int task_id); + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_POOLING_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc new file mode 100644 index 00000000000..52a46f9381c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/reshape_int8.h" +#include "src/runtime/kernel/arm/opclib/int8/reshape_int8.h" +#include "schema/model_generated.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { + +int ReshapeInt8CPUKernel::Init() { + ReshapeBaseCPUKernel::Init(); + auto *input_tensor = inputs_.at(kInputIndex); + auto in_quant_args = input_tensor->GetQuantParams(); + in_quant_arg_.scale_ = in_quant_args.front().scale; + in_quant_arg_.zp_ = in_quant_args.front().zeroPoint; + + auto *out_tensor = outputs_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + out_quant_arg_.scale_ = out_quant_args.front().scale; + out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; + return RET_OK; +} + +int ReshapeInt8CPUKernel::ReSize() { return 0; } + +int ReshapeInt8CPUKernel::Run() { + MS_ASSERT(inputs_.size() == 1); + MS_ASSERT(outputs_.size() == 1); + auto input_type = inputs_[kInputIndex]->data_type(); + auto input_num = inputs_[kInputIndex]->ElementsNum(); + auto output_num = outputs_.at(kOutputIndex)->ElementsNum(); + MS_ASSERT(input_num == output_num); + int8_t *input_ptr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + int8_t *output_ptr = reinterpret_cast(outputs_.at(kOutputIndex)->Data()); + if (input_type == kNumberTypeUInt8) { + auto *input_tmp = reinterpret_cast(inputs_.at(kInputIndex)->Data()); + for (size_t i = 0; i < input_num; i++) { + input_ptr[i] = (int8_t)(input_tmp[i] - 128); + } + in_quant_arg_.zp_ -= 128; + out_quant_arg_.zp_ -= 128; + } + + size_t data_size = inputs_.at(kInputIndex)->Size(); + Reshape(input_ptr, output_ptr, data_size, input_num, in_quant_arg_, out_quant_arg_); + + auto output_type = outputs_[kOutputIndex]->data_type(); + if (output_type == kNumberTypeUInt8) { + for (size_t i = 0; i < output_num; i++) { + output_ptr[i] = (uint8_t)(output_ptr[i] + 128); + } + } + return RET_OK; +} +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h new file mode 100644 index 00000000000..cb1065a4c00 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/base/reshape_base.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeInt8CPUKernel : public ReshapeBaseCPUKernel { + public: + ReshapeInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx) + : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {} + ~ReshapeInt8CPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt new file mode 100644 index 00000000000..9895fd458f4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/CMakeLists.txt @@ -0,0 +1,37 @@ +project(optimize) + +set(OPTIMIZED_OP_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../) +include_directories(OPTIMIZED_OP_DIR) + +########################### optimized files ########################### +set(FP16_ASSEMBLY +# ${OPTIMIZED_OP_DIR}/assembly/arm64/IndirectGemmFp16_16x8.s + ) + +file(GLOB_RECURSE OPTIMIZED_INT8_ASSEMBLY + ${OPTIMIZED_OP_DIR}/assembly/opt/*.S + ) + +########################### share library build ######################## +set(OPTIMIZED_OPS "opt_op_handler.c") + +set_property(SOURCE ${OPTIMIZED_INT8_ASSEMBLY} PROPERTY LANGUAGE C) +list(APPEND OPTIMIZED_OPS ${OPTIMIZED_INT8_ASSEMBLY} ${FP16_ASSEMBLY}) + +if (PLATFORM_ARM64) + string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + add_library(optimize SHARED ${OPTIMIZED_OPS}) + set_target_properties(optimize PROPERTIES CLEAN_DIRECT_OUTPUT 1) + + add_custom_command(TARGET optimize POST_BUILD + COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip + ${LITE_DIR}/build/src/runtime/kernel/arm/opclib/liboptimize.so) + + add_custom_command(TARGET optimize POST_BUILD + COMMAND rm -rf ${LITE_DIR}/output/lib/liboptimize.so + COMMAND mkdir -pv ${LITE_DIR}/output/lib + COMMAND cp ${LITE_DIR}/build/src/runtime/kernel/arm/opclib/liboptimize.so ${LITE_DIR}/output/lib) +endif () diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.cc new file mode 100644 index 00000000000..3b6d115c90e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/add_int8.h" +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +#ifdef ENABLE_NEON +int16x8_t LoadAndAddOffset(int8_t *data, int index, int offset) { + int8x8_t input_s8 = vld1_s8(data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + return vaddq_s16(input_s16, vdupq_n_s16(offset)); +} + +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec) { + int32x4_t shifted_input = vmulq_s32(input, left_shift_result_vec); + shifted_input = vqrdmulhq_s32(shifted_input, input_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(shifted_input, right_shift_vec), 31); + return vrshlq_s32(vqaddq_s32(shifted_input, fixup), right_shift_vec); +} + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, AddQuantParameter *para) { + int32x4_t raw_sum = vaddq_s32(scaled_input0, scaled_input1); + + raw_sum = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_sum, left_shift_out_vec), output_multiplier_vec), + para->right_shift_out_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para->output_offset_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para->output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void AddInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para, int *index) { + int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_); + int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_); + int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_); + int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_); + int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32((1 << para->left_shift_out_)); + int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_); + int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->input0_offset_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->input1_offset_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int32x4_t scaled_input0_low = + ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input0_high = + ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + ClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + ClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + } +} +#endif + +void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para) { + int index = 0; +#ifdef ENABLE_NEON + AddInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->input0_offset_ + input0_data[index]; + const int32_t input1_val = para->input1_offset_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_sum = scaled_input0_val + scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_sum * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->output_offset_; + if (raw_output > para->output_activation_max_) { + output_data[index] = para->output_activation_max_; + } else if (raw_output < para->output_activation_min_) { + output_data[index] = para->output_activation_min_; + } else { + output_data[index] = (int8_t)raw_output; + } + } + return; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.h new file mode 100644 index 00000000000..b947e76285c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/add_int8.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ADD_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ADD_INT8_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct AddQuantParameter { + int input0_offset_; + int input1_offset_; + int output_offset_; + float input0_scale_; + float input1_scale_; + float output_scale_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int output_activation_min_; + int output_activation_max_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +}; + +void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + AddQuantParameter *para); + +#ifdef ENABLE_NEON +#include +int16x8_t LoadAndAddOffset(int8_t *data, int index, int offset); +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ADD_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.cc new file mode 100644 index 00000000000..a39b8942321 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimension(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(uint8_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionUint8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, + ndim, inShape, inStrides, outStrides, multiple); + } + } +} + +void ComputeStrides(int *shape, int *strides, int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +void CalcMultiplesAndStrides(ArithmeticParameter *param) { + for (auto i = 0; i < param->ndim_; i++) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } + // cal strides + ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); + ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_); + ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_); +} + +void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimension(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimension(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionUint8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionUint8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionUint8((uint8_t *)(data0), (uint8_t *)(tile_data0), 0, param->ndim_, + param->in_shape0_, param->in_strides0_, param->out_strides_, param->multiples0_); + TileOneDimensionUint8((uint8_t *)(data1), (uint8_t *)(tile_data1), 0, param->ndim_, + param->in_shape1_, param->in_strides1_, param->out_strides_, param->multiples1_); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h new file mode 100644 index 00000000000..15d132c38fe --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_COMMON_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_COMMON_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +struct ArithmeticParameter { + OpParameter op_parameter; + bool broadcasting_; + size_t ndim_; + int in_shape0_[5]; + int in_shape1_[5]; + int out_shape_[5]; + + int in_strides0_[5]; + int in_strides1_[5]; + int out_strides_[5]; + + int multiples0_[5]; + int multiples1_[5]; +}; +void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *inShape, int *inStrides, + int *outStrides, int *multiple); +void ComputeStrides(int *shape, int *strides, int ndim); + +void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param); +void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1, + ArithmeticParameter *param); +void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_COMMON_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_parameter.h new file mode 100644 index 00000000000..af48c7c5c83 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARTITHMETIC_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARTITHMETIC_PARAMETER_H_ + +#include "src/runtime/kernel/arm/opclib/op_attribute.h" + + + + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARTITHMETIC_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmFp32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmFp32_8x4.S new file mode 100644 index 00000000000..4e4e8e27b9a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmFp32_8x4.S @@ -0,0 +1,294 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp32_8x4 +#ifndef __APPLE__ +.type IndirectGemmFp32_8x4, %function +#endif + +// void IndirectGemmFp32_8x4(float *output, float *input, float *weight, float *bias, +// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset +// r8:mode, r10: writeMode, x10: relu, r10:relu6 +// mode = 0 for general convolution, where one conv unit is a row +// mode = 1 for winograd/common gemm, where the total channels of one input is a row +IndirectGemmFp32_8x4: + + .macro INIT_BIAS + veor q10, q10, q10 + cbz x3, InitBias + vld1.32 q10, [x3] + InitBias: + vmov q11, q10 + vmov q12, q10 + vmov q13, q10 + vmov q14, q10 + vmov q15, q10 + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #160 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + cbnz r8, LoopOc + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul r5, r4, r5 + mov r4, #1 + + LoopOc: + mov r8, r4 + mov r12, r1 + + LoopKsize: + + mov r11, r0 + INIT_BIAS + + // load input for output 1-2 + vld1.32 {q0, q1, q2, q3}, [x12]! + // load weight + vld1.32 {q4, q5}, [x2]! + // step for output 1-2 + vmul.f32 q8, q4, d0[0] + vmul.f32 q9, q4, d2[0] + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [x2]! + + subs x10, x5, #1 + beq LoopIcEnd + + LoopIc: + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + vld1.s32 {q2, q3}, [r12]! + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + vld1.s32 {q4, q5}, [r2]! + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + vld1.s32 {q6, q7}, [r2]! + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.s32 {q2, q3}, [r12]! + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + vld1.s32 {q0, q1}, [r12]! + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + vld1.s32 {q2, q3}, [r12]! + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + + ldr r10, [sp, #28] + cbnz r10, Relu6 + ldr r10, [sp, #24] + cbnz x10, Relu + b WriteStart + Relu6: + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + vmin.f32 q4, q4, q14 + vmin.f32 q5, q5, q14 + vmin.f32 q6, q6, q14 + vmin.f32 q7, q15, q14 + Relu: + veor q7, q7, q7 + vmax.f32 q0, q8, q7 + vmax.f32 q1, q9, q7 + vmax.f32 q2, q10, q7 + vmax.f32 q3, q11, q7 + vmax.f32 q4, q12, q7 + vmax.f32 q5, q13, q7 + vmax.f32 q6, q14, q7 + vmax.f32 q15, q15, q7 + + WriteStart: + ldr r10, [sp, #20] + cbnz x10, WriteC4 + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + str s0, [r11] + add r11, r11, x7 + str s4, [r11] + add r11, r11, x7 + str s8, [r11] + add r11, r11, x7 + str s12, [r11] + add r11, r11, x7 + str s16, [r11] + add r11, r11, x7 + str s20, [r11] + add r11, r11, x7 + str s24, [r11] + add r11, r11, x7 + str s28, [r11] + add r0, r0, #4 + b WriteEnd + Write2: + str d0, [r11] + add r11, r11, x7 + str d2, [r11] + add r11, r11, x7 + str d4, [r11] + add r11, r11, x7 + str d6, [r11] + add r11, r11, x7 + str d8, [r11] + add r11, r11, x7 + str d10, [r11] + add r11, r11, x7 + str d12, [r11] + add r11, r11, x7 + str d14, [r11] + add r0, r0, #8 + b WriteEnd + Write3: + add r12, r11, #8 + str d0, [r11] + add r11, r11, x7 + str s2, [r12] + add r12, r12, r7 + str d2, [r11] + add r11, r11, x7 + str s6, [r12] + add r12, r12, r7 + str d4, [r11] + add r11, r11, x7 + str s10, [r12] + add r12, r12, r7 + str d6, [r11] + add r11, r11, x7 + str s14, [r12] + add r12, r12, r7 + str d8, [r11] + add r11, r11, x7 + str s18, [r12] + add r12, r12, r7 + str d10, [r11] + add r11, r11, x7 + str s22, [r12] + add r12, r12, r7 + str d12, [r11] + add r11, r11, x7 + str s26, [r12] + add r12, r12, r7 + str d14, [r11] + str s30, [r12] + add r0, r0, #12 + b WriteEnd + WriteC4: + vst1.32 q0, [r11], x7 + vst1.32 q1, [r11], x7 + vst1.32 q2, [r11], x7 + vst1.32 q3, [r11], x7 + vst1.32 q4, [r11], x7 + vst1.32 q5, [r11], x7 + vst1.32 q6, [r11], x7 + vst1.32 q7, [r11] + add r0, r0, #16 + b WriteEnd + Write4: + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + vst1.32 q0, [r11], x7 + vst1.32 q1, [r11], x7 + vst1.32 q2, [r11], x7 + vst1.32 q3, [r11], x7 + vst1.32 q4, [r11], x7 + vst1.32 q5, [r11], x7 + vst1.32 q6, [r11], x7 + vst1.32 q7, [r11] + add r0, r0, #16 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + subs r6, r6, #4 + cbz r3, NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + bgt LoopOc + + add sp, sp, #160 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt16to32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt16to32_8x4.S new file mode 100644 index 00000000000..cb77b02dfaf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,240 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt16to32_8x4 +#ifndef __APPLE__ +.type IndirectGemmInt16to32_8x4, %function +#endif + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t kszie, size_t ic8, size_t oc4, size_t offset); +// r0: output, r1: input, r2: weight, r3: kszie, r4: ic8, r5: oc4, r6: offset +IndirectGemmInt16to32_8x4: + + .macro INIT_ZERO + // we could also use "vmov.s32 q12, #0" to initialize q12 by 0 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, lr} + + ldr r4, [sp, #28] + ldr r5, [sp, #32] + ldr r6, [sp, #36] + + vpush {q4-q7} + + LoopOc: + + mov r7, r3 + mov r8, r1 + + LoopKsize: + mov r10, r0 + INIT_ZERO + + // load input + vld1.16 {q0, q1}, [r8]! + // load weight + vld1.16 {q4}, [r2]! + vmull.s16 q8, d8, d0[0] + vmull.s16 q9, d8, d2[0] + // load weight + vld1.16 {q5}, [r2]! + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + // load weight + vld1.16 {q6, q7}, [r2]! + vmull.s16 q10, d8, d4[0] + vmull.s16 q11, d8, d6[0] + + subs r12, r4, #1 + beq LoopIcEnd + + LoopIc: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + // load weight + vld1.16 {q4, q5}, [r2]! + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d8, d0[0] + vmlal.s16 q9, d8, d2[0] + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load weight + vld1.16 {q6, q7}, [r2]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + vmlal.s16 q10, d8, d4[0] + vmlal.s16 q11, d8, d6[0] + + subs r12, r12, #1 + bne LoopIc + + LoopIcEnd: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vst1.32 {q8}, [r10], r6 + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vst1.32 {q9}, [r10], r6 + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.s16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vst1.32 {q10}, [r10], r6 + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vst1.32 {q11}, [r10], r6 + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + vst1.32 {q12}, [r10], r6 + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vst1.32 {q13}, [r10], r6 + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + + vst1.32 {q14}, [r10], r6 + vst1.32 {q15}, [r10] + + subs r7, r7, #1 + add r0, r0, #16 + bne LoopKsize + + subs r5, r5, #1 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, r10, pc} + +#endif +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt8_2x4.S new file mode 100644 index 00000000000..6a23094490c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm32/IndirectGemmInt8_2x4.S @@ -0,0 +1,229 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_2x4 +#ifndef __APPLE__ +.type IndirectGemmInt8_2x4, %function +#endif + +// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// x0: output, x1: input, r2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +// x8: input_sum, x10: act_min, x11: act_max, x10: out_zp, x11: out_multiplier, x10: shift_before, x11: shift_after +IndirectGemmInt8_2x4: + + .macro INIT_BIAS + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #160 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mul r5, r4, r5 + mov r4, #1 + + LoopOc: + + mov r8, r4 + mov r12, r1 + + LoopKsize: + INIT_BIAS + mov r11, r0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-2 + vld1.8 {q0, q1}, [r12]! + // load weight for oc 1-2 + vld1.8 {q2, q3}, [r2]! + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vpaddl.s16 q8, q6 + vpaddl.s16 q9, q7 + // load weight for oc 3-4 + vld1.8 {q4, q5}, [r2]! + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs x10, x5, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1 + vld1.8 {q0}, [r12]! + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vld1.8 {q2, q3}, [r2]! + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vld1.8 {q4, q5}, [r2]! + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vld1.8 {q1}, [r12]! + vpadal.s16 q8, q6 + vpadal.s16 q9, q7 + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + + // load sum + ldr r10, [sp, #16] + vld1.32 q0[], [r10]! + vld1.32 q1[], [r10]! + // pairwise add + vpadd.i32 q8, q8, q9 + vpadd.i32 q10, q10, q11 + vpadd.i32 q12, q12, q13 + vpadd.i32 q14, q14, q15 + vpadd.i32 q8, q8, q10 + vpadd.i32 q12, q12, q14 + vsub.i32 q8, q8, q0 + vsub.i32 q12, q12, q1 + cbz r3, NoBias + vld1.32 q2, [r3] + vadd.i32 q8, q8, q2 + vadd.i32 q12, q12, q2 + + NoBias: + ldr r10, [sp, #36] + vdup.32 q3, r10 + vshl.s32 q8, q8, q3 + vshl.s32 q12, q12, q3 + + ldr r10, [sp, #32] + vdup.32 q4, r10 + vqrdmulh.s32 q8, q8, q4 + vqrdmulh.s32 q12, q12, q4 + + ldr r10, [sp, #40] + vdup.32 q5, r10 + vrshl.s32 q8, q8, q5 + vrshl.s32 q12, q12, q5 + + ldr r10, [sp, #28] + vdup.32 q6, r10 + vadd.i32 q8, q8, q6 + vadd.i32 q12, q12, q6 + + ldr r10, [sp, #20] + vdup.32 q0, r10 + vmax.s32 q8, q8, q0 + vmax.s32 q12, q12, q0 + + ldr r10, [sp, #24] + vdup.32 q1, r10 + vmin.s32 q8, q8, q1 + vmin.s32 q12, q12, q1 + + vqmovn.s32 d30, q8 + vqmovn.s32 d31, q12 + vqmovn.s16 d0, q14 + + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + vst1.8 {d0[0]}, [x11], x7 + vst1.8 {d0[1]}, [x11] + add r0, r0, #1 + b WriteEnd + Write2: + vst1.16 {d0[0]}, [x11], x7 + vst1.16 {d0[1]}, [x11] + add r0, r0, #2 + b WriteEnd + Write3: + add x14, x11, #2 + vst1.16 {d0[0]}, [x11], x7 + vst1.16 {d0[1]}, [x11] + vst1.8 {d0[0]}, [x14], x7 + vst1.8 {d0[1]}, [x14] + add r0, r0, #3 + b WriteEnd + Write4: + vst1.32 {d0[0]}, [x11], x7 + vst1.32 {d0[1]}, [x11] + add r0, r0, #4 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + subs r6, r6, #4 + cbz r3, NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + bgt LoopOc + + add sp, sp, #160 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp16_16x8.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp16_16x8.S new file mode 100644 index 00000000000..3c50aa362c2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp16_16x8.S @@ -0,0 +1,720 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp16_16x8 +#ifndef __APPLE__ +.type IndirectGemmFp16_16x8, %function +#endif + +// void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, +// size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset, +// x8:mode, x9: writeC4, x10:relu, x11: relu6 +// compute 8 channel for 16 outputs +IndirectGemmFp16_16x8: + + .macro INIT_BIAS + dup v16.4s, wzr + cbz x3, InitBias + ld1 {v16.8h}, [x3] + InitBias: + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + mov v24.16b, v16.16b + mov v25.16b, v16.16b + mov v26.16b, v16.16b + mov v27.16b, v16.16b + mov v28.16b, v16.16b + mov v29.16b, v16.16b + mov v30.16b, v16.16b + mov v31.16b, v16.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + // performance between storing 4 registers at the same time and seperatly storing them on in-order cores + // is not tested yet + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + ldr x8, [sp, #0] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + + cbnz x8, IndirectGemmStart + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul x5, x4, x5 + mov x4, #1 + +IndirectGemmStart: + + LoopOc: + + mov x14, x4 + mov x12, x1 + + LoopKsize: + + mov x15, x0 + INIT_BIAS + // load input for output 1-8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + // load weight + ld1 {v8.8h, v9.8h}, [x2], #32 + // first 2 steps for output 1 and 3 + fmla v16.8h, v8.8h, v0.h[0] + fmla v18.8h, v8.8h, v1.h[0] + fmla v16.8h, v9.8h, v0.h[1] + fmla v18.8h, v9.8h, v1.h[1] + // load weight + ld1 {v10.8h, v11.8h}, [x2], #32 + // first 2 steps for output 2 and 4 + fmla v17.8h, v8.8h, v0.h[4] + fmla v19.8h, v8.8h, v1.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v19.8h, v9.8h, v1.h[5] + // load input for output 9-16 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + // last 2 steps for output 1 and 3 + fmla v16.8h, v10.8h, v0.h[2] + fmla v18.8h, v10.8h, v1.h[2] + fmla v16.8h, v11.8h, v0.h[3] + fmla v18.8h, v11.8h, v1.h[3] + + // check if ic4=1 + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + // last 2 steps for output 2 and 4 + fmla v17.8h, v10.8h, v0.h[6] + fmla v19.8h, v10.8h, v1.h[6] + fmla v17.8h, v11.8h, v0.h[7] + fmla v19.8h, v11.8h, v1.h[7] + // steps for output 5-8 + fmla v20.8h, v8.8h, v2.h[0] + fmla v22.8h, v8.8h, v3.h[0] + fmla v20.8h, v9.8h, v2.h[1] + fmla v22.8h, v9.8h, v3.h[1] + fmla v21.8h, v8.8h, v2.h[4] + fmla v23.8h, v8.8h, v3.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v23.8h, v9.8h, v3.h[5] + fmla v20.8h, v10.8h, v2.h[2] + fmla v22.8h, v10.8h, v3.h[2] + fmla v20.8h, v11.8h, v2.h[3] + fmla v22.8h, v11.8h, v3.h[3] + fmla v21.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v3.h[6] + fmla v21.8h, v11.8h, v2.h[7] + fmla v23.8h, v11.8h, v3.h[7] + // load input for output 1-8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + // steps for output 9-12 + fmla v24.8h, v8.8h, v4.h[0] + fmla v26.8h, v8.8h, v5.h[0] + fmla v24.8h, v9.8h, v4.h[1] + fmla v26.8h, v9.8h, v5.h[1] + fmla v25.8h, v8.8h, v4.h[4] + fmla v27.8h, v8.8h, v5.h[4] + fmla v25.8h, v9.8h, v4.h[5] + fmla v27.8h, v9.8h, v5.h[5] + fmla v24.8h, v10.8h, v4.h[2] + fmla v26.8h, v10.8h, v5.h[2] + fmla v24.8h, v11.8h, v4.h[3] + fmla v26.8h, v11.8h, v5.h[3] + fmla v25.8h, v10.8h, v4.h[6] + fmla v27.8h, v10.8h, v5.h[6] + fmla v25.8h, v11.8h, v4.h[7] + fmla v27.8h, v11.8h, v5.h[7] + // steps for output 13-16 + fmla v28.8h, v8.8h, v6.h[0] + fmla v30.8h, v8.8h, v7.h[0] + fmla v28.8h, v9.8h, v6.h[1] + fmla v30.8h, v9.8h, v7.h[1] + fmla v29.8h, v8.8h, v6.h[4] + fmla v31.8h, v8.8h, v7.h[4] + fmla v29.8h, v9.8h, v6.h[5] + fmla v31.8h, v9.8h, v7.h[5] + // load weight + ld1 {v8.8h, v9.8h}, [x2], #32 + fmla v28.8h, v10.8h, v6.h[2] + fmla v30.8h, v10.8h, v7.h[2] + fmla v28.8h, v11.8h, v6.h[3] + fmla v30.8h, v11.8h, v7.h[3] + fmla v29.8h, v10.8h, v6.h[6] + fmla v31.8h, v10.8h, v7.h[6] + fmla v29.8h, v11.8h, v6.h[7] + fmla v31.8h, v11.8h, v7.h[7] + // load weight + ld1 {v10.8h, v11.8h}, [x2], #32 + // first 2 steps for output 1-4 + fmla v16.8h, v8.8h, v0.h[0] + fmla v18.8h, v8.8h, v1.h[0] + fmla v16.8h, v9.8h, v0.h[1] + fmla v18.8h, v9.8h, v1.h[1] + fmla v17.8h, v8.8h, v0.h[4] + fmla v19.8h, v8.8h, v1.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v19.8h, v9.8h, v1.h[5] + // load input for output 9-16 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + // last 2 steps for output 1 and 3 + fmla v16.8h, v10.8h, v0.h[2] + fmla v18.8h, v10.8h, v1.h[2] + fmla v16.8h, v11.8h, v0.h[3] + fmla v18.8h, v11.8h, v1.h[3] + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + fmla v17.8h, v10.8h, v0.h[6] + fmla v19.8h, v10.8h, v1.h[6] + fmla v17.8h, v11.8h, v0.h[7] + fmla v19.8h, v11.8h, v1.h[7] + // steps for output 5-8 + fmla v20.8h, v8.8h, v2.h[0] + fmla v22.8h, v8.8h, v3.h[0] + fmla v20.8h, v9.8h, v2.h[1] + fmla v22.8h, v9.8h, v3.h[1] + fmla v21.8h, v8.8h, v2.h[4] + fmla v23.8h, v8.8h, v3.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v23.8h, v9.8h, v3.h[5] + fmla v20.8h, v10.8h, v2.h[2] + fmla v22.8h, v10.8h, v3.h[2] + fmla v20.8h, v11.8h, v2.h[3] + fmla v22.8h, v11.8h, v3.h[3] + fmla v21.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v3.h[6] + fmla v21.8h, v11.8h, v2.h[7] + fmla v23.8h, v11.8h, v3.h[7] + // steps for output 9-12 + fmla v24.8h, v8.8h, v4.h[0] + fmla v26.8h, v8.8h, v5.h[0] + fmla v24.8h, v9.8h, v4.h[1] + fmla v26.8h, v9.8h, v5.h[1] + fmla v25.8h, v8.8h, v4.h[4] + fmla v27.8h, v8.8h, v5.h[4] + fmla v25.8h, v9.8h, v4.h[5] + fmla v27.8h, v9.8h, v5.h[5] + fmla v24.8h, v10.8h, v4.h[2] + fmla v26.8h, v10.8h, v5.h[2] + fmla v24.8h, v11.8h, v4.h[3] + fmla v26.8h, v11.8h, v5.h[3] + fmla v25.8h, v10.8h, v4.h[6] + fmla v27.8h, v10.8h, v5.h[6] + fmla v25.8h, v11.8h, v4.h[7] + fmla v27.8h, v11.8h, v5.h[7] + // steps for output 13-16 + fmla v28.8h, v8.8h, v6.h[0] + fmla v30.8h, v8.8h, v7.h[0] + fmla v28.8h, v9.8h, v6.h[1] + fmla v30.8h, v9.8h, v7.h[1] + fmla v29.8h, v8.8h, v6.h[4] + fmla v31.8h, v8.8h, v7.h[4] + fmla v29.8h, v9.8h, v6.h[5] + fmla v31.8h, v9.8h, v7.h[5] + fmla v28.8h, v10.8h, v6.h[2] + fmla v30.8h, v10.8h, v7.h[2] + fmla v28.8h, v11.8h, v6.h[3] + fmla v30.8h, v11.8h, v7.h[3] + fmla v29.8h, v10.8h, v6.h[6] + fmla v31.8h, v10.8h, v7.h[6] + fmla v29.8h, v11.8h, v6.h[7] + fmla v31.8h, v11.8h, v7.h[7] + + cbnz x11, Relu6 + cbnz x10, Relu + b WriteStart + Relu6: + movi v9.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v9.8h + fmin v17.8h, v17.8h, v9.8h + fmin v18.8h, v18.8h, v9.8h + fmin v19.8h, v19.8h, v9.8h + fmin v20.8h, v20.8h, v9.8h + fmin v21.8h, v21.8h, v9.8h + fmin v22.8h, v22.8h, v9.8h + fmin v23.8h, v23.8h, v9.8h + fmin v24.8h, v24.8h, v9.8h + fmin v25.8h, v25.8h, v9.8h + fmin v26.8h, v26.8h, v9.8h + fmin v27.8h, v27.8h, v9.8h + fmin v28.8h, v28.8h, v9.8h + fmin v29.8h, v29.8h, v9.8h + fmin v30.8h, v30.8h, v9.8h + fmin v31.8h, v31.8h, v9.8h + Relu: + dup v8.4s, wzr + fmax v16.8h, v16.8h, v8.8h + fmax v17.8h, v17.8h, v8.8h + fmax v18.8h, v18.8h, v8.8h + fmax v19.8h, v19.8h, v8.8h + fmax v20.8h, v20.8h, v8.8h + fmax v21.8h, v21.8h, v8.8h + fmax v22.8h, v22.8h, v8.8h + fmax v23.8h, v23.8h, v8.8h + fmax v24.8h, v24.8h, v8.8h + fmax v25.8h, v25.8h, v8.8h + fmax v26.8h, v26.8h, v8.8h + fmax v27.8h, v27.8h, v8.8h + fmax v28.8h, v28.8h, v8.8h + fmax v29.8h, v29.8h, v8.8h + fmax v30.8h, v30.8h, v8.8h + fmax v31.8h, v31.8h, v8.8h + + WriteStart: + cbnz x9, Write8 + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + cmp x6, #4 + beq Write4 + cmp x6, #5 + beq Write5 + cmp x6, #6 + beq Write6 + cmp x6, #7 + beq Write7 + b Write8 + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + Write1: + str h16, [x15] + add x15, x15, x7 + str h17, [x15] + add x15, x15, x7 + str h18, [x15] + add x15, x15, x7 + str h19, [x15] + add x15, x15, x7 + str h20, [x15] + add x15, x15, x7 + str h21, [x15] + add x15, x15, x7 + str h22, [x15] + add x15, x15, x7 + str h23, [x15] + add x15, x15, x7 + str h24, [x15] + add x15, x15, x7 + str h25, [x15] + add x15, x15, x7 + str h26, [x15] + add x15, x15, x7 + str h27, [x15] + add x15, x15, x7 + str h28, [x15] + add x15, x15, x7 + str h29, [x15] + add x15, x15, x7 + str h30, [x15] + add x15, x15, x7 + str h31, [x15] + add x0, x0, #2 + b WriteEnd + Write2: + str s16, [x15] + add x15, x15, x7 + str s17, [x15] + add x15, x15, x7 + str s18, [x15] + add x15, x15, x7 + str s19, [x15] + add x15, x15, x7 + str s20, [x15] + add x15, x15, x7 + str s21, [x15] + add x15, x15, x7 + str s22, [x15] + add x15, x15, x7 + str s23, [x15] + add x15, x15, x7 + str s24, [x15] + add x15, x15, x7 + str s25, [x15] + add x15, x15, x7 + str s26, [x15] + add x15, x15, x7 + str s27, [x15] + add x15, x15, x7 + str s28, [x15] + add x15, x15, x7 + str s29, [x15] + add x15, x15, x7 + str s30, [x15] + add x15, x15, x7 + str s31, [x15] + add x0, x0, #4 + b WriteEnd + Write3: + add x17, x15, #4 + str s16, [x15] + add x15, x15, x7 + st1 {v16.h}[2], [x17], x7 + str s17, [x15] + add x15, x15, x7 + st1 {v17.h}[2], [x17], x7 + str s18, [x15] + add x15, x15, x7 + st1 {v18.h}[2], [x17], x7 + str s19, [x15] + add x15, x15, x7 + st1 {v19.h}[2], [x17], x7 + str s20, [x15] + add x15, x15, x7 + st1 {v20.h}[2], [x17], x7 + str s21, [x15] + add x15, x15, x7 + st1 {v21.h}[2], [x17], x7 + str s22, [x15] + add x15, x15, x7 + st1 {v22.h}[2], [x17], x7 + str s23, [x15] + add x15, x15, x7 + st1 {v23.h}[2], [x17], x7 + str s24, [x15] + add x15, x15, x7 + st1 {v24.h}[2], [x17], x7 + str s25, [x15] + add x15, x15, x7 + st1 {v25.h}[2], [x17], x7 + str s26, [x15] + add x15, x15, x7 + st1 {v26.h}[2], [x17], x7 + str s27, [x15] + add x15, x15, x7 + st1 {v27.h}[2], [x17], x7 + str s28, [x15] + add x15, x15, x7 + st1 {v28.h}[2], [x17], x7 + str s29, [x15] + add x15, x15, x7 + st1 {v29.h}[2], [x17], x7 + str s30, [x15] + add x15, x15, x7 + st1 {v30.h}[2], [x17], x7 + str s31, [x15] + st1 {v31.h}[2], [x17] + add x0, x0, #6 + b WriteEnd + Write4: + str d16, [x15] + add x15, x15, x7 + str d17, [x15] + add x15, x15, x7 + str d18, [x15] + add x15, x15, x7 + str d19, [x15] + add x15, x15, x7 + str d20, [x15] + add x15, x15, x7 + str d21, [x15] + add x15, x15, x7 + str d22, [x15] + add x15, x15, x7 + str d23, [x15] + add x15, x15, x7 + str d24, [x15] + add x15, x15, x7 + str d25, [x15] + add x15, x15, x7 + str d26, [x15] + add x15, x15, x7 + str d27, [x15] + add x15, x15, x7 + str d28, [x15] + add x15, x15, x7 + str d29, [x15] + add x15, x15, x7 + str d30, [x15] + add x15, x15, x7 + str d31, [x15] + add x0, x0, #8 + b WriteEnd + Write5: + add x17, x15, #8 + str d16, [x15] + add x15, x15, x7 + st1 {v16.h}[4], [x17], x7 + str d17, [x15] + add x15, x15, x7 + st1 {v17.h}[4], [x17], x7 + str d18, [x15] + add x15, x15, x7 + st1 {v18.h}[4], [x17], x7 + str d19, [x15] + add x15, x15, x7 + st1 {v19.h}[4], [x17], x7 + str d20, [x15] + add x15, x15, x7 + st1 {v20.h}[4], [x17], x7 + str d21, [x15] + add x15, x15, x7 + st1 {v21.h}[4], [x17], x7 + str d22, [x15] + add x15, x15, x7 + st1 {v22.h}[4], [x17], x7 + str d23, [x15] + add x15, x15, x7 + st1 {v23.h}[4], [x17], x7 + str d24, [x15] + add x15, x15, x7 + st1 {v24.h}[4], [x17], x7 + str d25, [x15] + add x15, x15, x7 + st1 {v25.h}[4], [x17], x7 + str d26, [x15] + add x15, x15, x7 + st1 {v26.h}[4], [x17], x7 + str d27, [x15] + add x15, x15, x7 + st1 {v27.h}[4], [x17], x7 + str d28, [x15] + add x15, x15, x7 + st1 {v28.h}[4], [x17], x7 + str d29, [x15] + add x15, x15, x7 + st1 {v29.h}[4], [x17], x7 + str d30, [x15] + add x15, x15, x7 + st1 {v30.h}[4], [x17], x7 + str d31, [x15] + st1 {v31.h}[4], [x17] + add x0, x0, #10 + b WriteEnd + Write6: + add x17, x15, #8 + str d16, [x15] + add x15, x15, x7 + ins v0.s[0], v16.s[2] + str s0, [x17] + add x17, x17, x7 + str d17, [x15] + add x15, x15, x7 + ins v1.s[0], v17.s[2] + str s1, [x17] + add x17, x17, x7 + str d18, [x15] + add x15, x15, x7 + ins v2.s[0], v18.s[2] + str s2, [x17] + add x17, x17, x7 + str d19, [x15] + add x15, x15, x7 + ins v3.s[0], v19.s[2] + str s3, [x17] + add x17, x17, x7 + str d20, [x15] + add x15, x15, x7 + ins v4.s[0], v20.s[2] + str s4, [x17] + add x17, x17, x7 + str d21, [x15] + add x15, x15, x7 + ins v5.s[0], v21.s[2] + str s5, [x17] + add x17, x17, x7 + str d22, [x15] + add x15, x15, x7 + ins v6.s[0], v22.s[2] + str s6, [x17] + add x17, x17, x7 + str d23, [x15] + add x15, x15, x7 + ins v7.s[0], v23.s[2] + str s7, [x17] + add x17, x17, x7 + str d24, [x15] + add x15, x15, x7 + ins v8.s[0], v24.s[2] + str s8, [x17] + add x17, x17, x7 + str d25, [x15] + add x15, x15, x7 + ins v9.s[0], v25.s[2] + str s9, [x17] + add x17, x17, x7 + str d26, [x15] + add x15, x15, x7 + ins v10.s[0], v26.s[2] + str s10, [x17] + add x17, x17, x7 + str d27, [x15] + add x15, x15, x7 + ins v11.s[0], v27.s[2] + str s11, [x17] + add x17, x17, x7 + str d28, [x15] + add x15, x15, x7 + ins v12.s[0], v28.s[2] + str s12, [x17] + add x17, x17, x7 + str d29, [x15] + add x15, x15, x7 + ins v13.s[0], v29.s[2] + str s13, [x17] + add x17, x17, x7 + str d30, [x15] + add x15, x15, x7 + ins v14.s[0], v30.s[2] + str s14, [x17] + add x17, x17, x7 + str d31, [x15] + ins v15.s[0], v31.s[2] + str s15, [x17] + add x0, x0, #12 + b WriteEnd + Write7: + add x17, x15, #8 + add x16, x15, #12 + str d16, [x15] + add x15, x15, x7 + ins v0.s[0], v16.s[2] + str s0, [x17] + add x17, x17, x7 + st1 {v16.h}[6], [x16], x7 + str d17, [x15] + add x15, x15, x7 + ins v1.s[0], v17.s[2] + str s1, [x17] + add x17, x17, x7 + st1 {v17.h}[6], [x16], x7 + str d18, [x15] + add x15, x15, x7 + ins v2.s[0], v18.s[2] + str s2, [x17] + add x17, x17, x7 + st1 {v18.h}[6], [x16], x7 + str d19, [x15] + add x15, x15, x7 + ins v3.s[0], v19.s[2] + str s3, [x17] + add x17, x17, x7 + st1 {v19.h}[6], [x16], x7 + str d20, [x15] + add x15, x15, x7 + ins v4.s[0], v20.s[2] + str s4, [x17] + add x17, x17, x7 + st1 {v20.h}[6], [x16], x7 + str d21, [x15] + add x15, x15, x7 + ins v5.s[0], v21.s[2] + str s5, [x17] + add x17, x17, x7 + st1 {v21.h}[6], [x16], x7 + str d22, [x15] + add x15, x15, x7 + ins v6.s[0], v22.s[2] + str s6, [x17] + add x17, x17, x7 + st1 {v22.h}[6], [x16], x7 + str d23, [x15] + add x15, x15, x7 + ins v7.s[0], v23.s[2] + str s7, [x17] + add x17, x17, x7 + st1 {v23.h}[6], [x16], x7 + str d24, [x15] + add x15, x15, x7 + ins v8.s[0], v24.s[2] + str s8, [x17] + add x17, x17, x7 + st1 {v24.h}[6], [x16], x7 + str d25, [x15] + add x15, x15, x7 + ins v9.s[0], v25.s[2] + str s9, [x17] + add x17, x17, x7 + st1 {v25.h}[6], [x16], x7 + str d26, [x15] + add x15, x15, x7 + ins v10.s[0], v26.s[2] + str s10, [x17] + add x17, x17, x7 + st1 {v26.h}[6], [x16], x7 + str d27, [x15] + add x15, x15, x7 + ins v11.s[0], v27.s[2] + str s11, [x17] + add x17, x17, x7 + st1 {v27.h}[6], [x16], x7 + str d28, [x15] + add x15, x15, x7 + ins v12.s[0], v28.s[2] + str s12, [x17] + add x17, x17, x7 + st1 {v28.h}[6], [x16], x7 + str d29, [x15] + add x15, x15, x7 + ins v13.s[0], v29.s[2] + str s13, [x17] + add x17, x17, x7 + st1 {v29.h}[6], [x16], x7 + str d30, [x15] + add x15, x15, x7 + ins v14.s[0], v30.s[2] + str s14, [x17] + add x17, x17, x7 + st1 {v30.h}[6], [x16], x7 + str d31, [x15] + ins v15.s[0], v31.s[2] + str s15, [x17] + st1 {v31.h}[6], [x16] + add x0, x0, #14 + b WriteEnd + Write8: + st1 {v16.8h}, [x15], x7 + st1 {v17.8h}, [x15], x7 + st1 {v18.8h}, [x15], x7 + st1 {v19.8h}, [x15], x7 + st1 {v20.8h}, [x15], x7 + st1 {v21.8h}, [x15], x7 + st1 {v22.8h}, [x15], x7 + st1 {v23.8h}, [x15], x7 + st1 {v24.8h}, [x15], x7 + st1 {v25.8h}, [x15], x7 + st1 {v26.8h}, [x15], x7 + st1 {v27.8h}, [x15], x7 + st1 {v28.8h}, [x15], x7 + st1 {v29.8h}, [x15], x7 + st1 {v30.8h}, [x15], x7 + st1 {v31.8h}, [x15] + add x0, x0, #16 + + WriteEnd: + subs x14, x14, #1 + bne LoopKsize + + subs x6, x6, #8 + cbz x3, NoStepForward + add x3, x3, #16 + NoStepForward: + bgt LoopOc + + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp32_8x8.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp32_8x8.S new file mode 100644 index 00000000000..be649b0e588 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmFp32_8x8.S @@ -0,0 +1,730 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmFp32_8x8 +#ifndef __APPLE__ +.type IndirectGemmFp32_8x8, %function +#endif + +// void IndirectGemmFp32_8x8(float *output, float *input, float *weight, float *bias, +// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +// x8:mode, x9: writeMode, x10: relu, x11:relu6 +// mode = 0 for general convolution, where one conv unit is a row +// mode = 1 for winograd/common gemm, where the total channels of one input is a row +IndirectGemmFp32_8x8: + + .macro INIT_BIAS + dup v16.4s, wzr + dup v17.4s, wzr + cbz x3, InitBias + ld1 {v16.4s, v17.4s}, [x3] + InitBias: + mov v18.16b, v16.16b + mov v19.16b, v17.16b + mov v20.16b, v16.16b + mov v21.16b, v17.16b + mov v22.16b, v16.16b + mov v23.16b, v17.16b + mov v24.16b, v16.16b + mov v25.16b, v17.16b + mov v26.16b, v16.16b + mov v27.16b, v17.16b + mov v28.16b, v16.16b + mov v29.16b, v17.16b + mov v30.16b, v16.16b + mov v31.16b, v17.16b + .endm + + .macro INIT_BIAS_HALF + dup v16.4s, wzr + cbz x3, InitBiasHalf + ld1 {v16.4s}, [x3] + InitBiasHalf: + mov v18.16b, v16.16b + mov v20.16b, v16.16b + mov v22.16b, v16.16b + mov v24.16b, v16.16b + mov v26.16b, v16.16b + mov v28.16b, v16.16b + mov v30.16b, v16.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + ldr x8, [sp, #0] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x11, [sp, #24] + + cbnz x8, NoStepShuffle + // step is one for common convolution, where ic8 should multiply by kernel size + // step is (a+b-1) for F(a,b) in winograd + mul x5, x4, x5 + mov x4, #1 + +NoStepShuffle: + // x8 is used to store offset now + // only useful for WriteC4 + mov x8, #16 + mul x8, x8, x4 + +IndirectGemmStart: + + cmp x6, #4 + ble LoopOcHalf + + LoopOc: + + mov x14, x4 + mov x12, x1 + + LoopKsize: + + mov x15, x0 + INIT_BIAS + + // load input for output 1-2 + ld1 {v0.4s, v1.4s}, [x12], #32 + // load weight + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + // step for output 1-2 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v19.4s, v9.4s, v1.s[0] + // load input for output 3-4 + ld1 {v2.4s, v3.4s}, [x12], #32 + // another step for output 1-2 + fmla v16.4s, v10.4s, v0.s[1] + fmla v17.4s, v11.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v19.4s, v11.4s, v1.s[1] + // load input for output 5-8 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 3-8 + fmla v20.4s, v8.4s, v2.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + fmla v23.4s, v9.4s, v3.s[0] + + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + fmla v24.4s, v8.4s, v4.s[0] + fmla v25.4s, v9.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v27.4s, v9.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v29.4s, v9.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + fmla v31.4s, v9.4s, v7.s[0] + // load weight + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v21.4s, v11.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + fmla v23.4s, v11.4s, v3.s[1] + fmla v24.4s, v10.4s, v4.s[1] + fmla v25.4s, v11.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v27.4s, v11.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v29.4s, v11.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + fmla v31.4s, v11.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v17.4s, v13.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v19.4s, v13.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v21.4s, v13.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v23.4s, v13.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v25.4s, v13.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v27.4s, v13.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v29.4s, v13.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + fmla v31.4s, v13.4s, v7.s[2] + // load weight + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v17.4s, v15.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v19.4s, v15.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v21.4s, v15.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v23.4s, v15.4s, v3.s[3] + // load input for output 1-4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + fmla v24.4s, v14.4s, v4.s[3] + fmla v25.4s, v15.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v27.4s, v15.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v29.4s, v15.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] + // load input for output 5-8 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 1-8 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v19.4s, v9.4s, v1.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v17.4s, v11.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v19.4s, v11.4s, v1.s[1] + fmla v20.4s, v8.4s, v2.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + fmla v23.4s, v9.4s, v3.s[0] + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + fmla v24.4s, v8.4s, v4.s[0] + fmla v25.4s, v9.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v27.4s, v9.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v29.4s, v9.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + fmla v31.4s, v9.4s, v7.s[0] + // load weight + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v21.4s, v11.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + fmla v23.4s, v11.4s, v3.s[1] + fmla v24.4s, v10.4s, v4.s[1] + fmla v25.4s, v11.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v27.4s, v11.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v29.4s, v11.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + fmla v31.4s, v11.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v17.4s, v13.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v19.4s, v13.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v21.4s, v13.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v23.4s, v13.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v25.4s, v13.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v27.4s, v13.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v29.4s, v13.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + fmla v31.4s, v13.4s, v7.s[2] + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v17.4s, v15.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v19.4s, v15.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v21.4s, v15.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v23.4s, v15.4s, v3.s[3] + fmla v24.4s, v14.4s, v4.s[3] + fmla v25.4s, v15.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v27.4s, v15.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v29.4s, v15.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + fmla v31.4s, v15.4s, v7.s[3] + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + cbnz x11, Relu6 + cbnz x10, Relu + b WriteStart + Relu6: + movi v1.4s, #6 + scvtf v1.4s, v1.4s + fmin v16.4s, v16.4s ,v1.4s + fmin v17.4s, v17.4s ,v1.4s + fmin v18.4s, v18.4s ,v1.4s + fmin v19.4s, v19.4s ,v1.4s + fmin v20.4s, v20.4s ,v1.4s + fmin v21.4s, v21.4s ,v1.4s + fmin v22.4s, v22.4s ,v1.4s + fmin v23.4s, v23.4s ,v1.4s + fmin v24.4s, v24.4s ,v1.4s + fmin v25.4s, v25.4s ,v1.4s + fmin v26.4s, v26.4s ,v1.4s + fmin v27.4s, v27.4s ,v1.4s + fmin v28.4s, v28.4s ,v1.4s + fmin v29.4s, v29.4s ,v1.4s + fmin v30.4s, v30.4s ,v1.4s + fmin v31.4s, v31.4s ,v1.4s + Relu: + dup v0.4s, wzr + fmax v16.4s, v16.4s ,v0.4s + fmax v17.4s, v17.4s ,v0.4s + fmax v18.4s, v18.4s ,v0.4s + fmax v19.4s, v19.4s ,v0.4s + fmax v20.4s, v20.4s ,v0.4s + fmax v21.4s, v21.4s ,v0.4s + fmax v22.4s, v22.4s ,v0.4s + fmax v23.4s, v23.4s ,v0.4s + fmax v24.4s, v24.4s ,v0.4s + fmax v25.4s, v25.4s ,v0.4s + fmax v26.4s, v26.4s ,v0.4s + fmax v27.4s, v27.4s ,v0.4s + fmax v28.4s, v28.4s ,v0.4s + fmax v29.4s, v29.4s ,v0.4s + fmax v30.4s, v30.4s ,v0.4s + fmax v31.4s, v31.4s ,v0.4s + + WriteStart: + cbnz x9, WriteC4 + cmp x6, #5 + beq Write5 + cmp x6, #6 + beq Write6 + cmp x6, #7 + beq Write7 + b Write8 + Write5: + add x17, x15, #16 + st1 {v16.4s}, [x15], x7 + str s17, [x17] + add x17, x17, x7 + st1 {v18.4s}, [x15], x7 + str s19, [x17] + add x17, x17, x7 + st1 {v20.4s}, [x15], x7 + str s21, [x17] + add x17, x17, x7 + st1 {v22.4s}, [x15], x7 + str s23, [x17] + add x17, x17, x7 + st1 {v24.4s}, [x15], x7 + str s25, [x17] + add x17, x17, x7 + st1 {v26.4s}, [x15], x7 + str s27, [x17] + add x17, x17, x7 + st1 {v28.4s}, [x15], x7 + str s29, [x17] + add x17, x17, x7 + st1 {v30.4s}, [x15] + str s31, [x17] + add x0, x0, #20 + b WriteEnd + Write6: + add x17, x15, #16 + st1 {v16.4s}, [x15], x7 + dup s16, v17.s[1] + stp s17, s16, [x17] + add x17, x17, x7 + st1 {v18.4s}, [x15], x7 + dup s18, v19.s[1] + stp s19, s18, [x17] + add x17, x17, x7 + st1 {v20.4s}, [x15], x7 + dup s20, v21.s[1] + stp s21, s20, [x17] + add x17, x17, x7 + st1 {v22.4s}, [x15], x7 + dup s22, v23.s[1] + stp s23, s22, [x17] + add x17, x17, x7 + st1 {v24.4s}, [x15], x7 + dup s24, v25.s[1] + stp s25, s24, [x17] + add x17, x17, x7 + st1 {v26.4s}, [x15], x7 + dup s26, v27.s[1] + stp s27, s26, [x17] + add x17, x17, x7 + st1 {v28.4s}, [x15], x7 + dup s28, v29.s[1] + stp s29, s28, [x17] + add x17, x17, x7 + st1 {v30.4s}, [x15] + dup s30, v31.s[1] + stp s31, s30, [x17] + add x0, x0, #24 + b WriteEnd + Write7: + add x17, x15, #16 + add x16, x15, #24 + st1 {v16.4s}, [x15], x7 + dup s16, v17.s[1] + stp s17, s16, [x17] + add x17, x17, x7 + st1 {v17.s}[2], [x16], x7 + st1 {v18.4s}, [x15], x7 + dup s18, v19.s[1] + stp s19, s18, [x17] + add x17, x17, x7 + st1 {v19.s}[2], [x16], x7 + st1 {v20.4s}, [x15], x7 + dup s20, v21.s[1] + stp s21, s20, [x17] + add x17, x17, x7 + st1 {v21.s}[2], [x16], x7 + st1 {v22.4s}, [x15], x7 + dup s22, v23.s[1] + stp s23, s22, [x17] + add x17, x17, x7 + st1 {v23.s}[2], [x16], x7 + st1 {v24.4s}, [x15], x7 + dup s24, v25.s[1] + stp s25, s24, [x17] + add x17, x17, x7 + st1 {v25.s}[2], [x16], x7 + st1 {v26.4s}, [x15], x7 + dup s26, v27.s[1] + stp s27, s26, [x17] + add x17, x17, x7 + st1 {v27.s}[2], [x16], x7 + st1 {v28.4s}, [x15], x7 + dup s28, v29.s[1] + stp s29, s28, [x17] + add x17, x17, x7 + st1 {v29.s}[2], [x16], x7 + st1 {v30.4s}, [x15], x7 + dup s30, v31.s[1] + stp s31, s30, [x17] + add x17, x17, x7 + st1 {v31.s}[2], [x16], x7 + add x0, x0, #28 + b WriteEnd + WriteC4: + st1 {v16.4s}, [x15], x7 + st1 {v18.4s}, [x15], x7 + st1 {v20.4s}, [x15], x7 + st1 {v22.4s}, [x15], x7 + st1 {v24.4s}, [x15], x7 + st1 {v26.4s}, [x15], x7 + st1 {v28.4s}, [x15], x7 + st1 {v30.4s}, [x15] + add x15, x8, x0 + st1 {v17.4s}, [x15], x7 + st1 {v19.4s}, [x15], x7 + st1 {v21.4s}, [x15], x7 + st1 {v23.4s}, [x15], x7 + st1 {v25.4s}, [x15], x7 + st1 {v27.4s}, [x15], x7 + st1 {v29.4s}, [x15], x7 + st1 {v31.4s}, [x15] + add x0, x0, #16 + b WriteEnd + Write8: + st1 {v16.4s, v17.4s}, [x15], x7 + st1 {v18.4s, v19.4s}, [x15], x7 + st1 {v20.4s, v21.4s}, [x15], x7 + st1 {v22.4s, v23.4s}, [x15], x7 + st1 {v24.4s, v25.4s}, [x15], x7 + st1 {v26.4s, v27.4s}, [x15], x7 + st1 {v28.4s, v29.4s}, [x15], x7 + st1 {v30.4s, v31.4s}, [x15] + add x0, x0, #32 + + WriteEnd: + + subs x14, x14, #1 + bne LoopKsize + + subs x6, x6, #8 + ble LoopOcEnd + cbz x9, NoStepC4Block + add x0, x0, x8 + NoStepC4Block: + cbz x3, NoStepForward + add x3, x3, #32 + NoStepForward: + cmp x6, #4 + bgt LoopOc + + LoopOcHalf: + mov x18, #32 + + mov x14, x4 + mov x12, x1 + + LoopKsizeHalf: + + mov x15, x0 + INIT_BIAS_HALF + + // load input for output 1-2 + ld1 {v0.4s, v1.4s}, [x12], #32 + // load weight + ld1 {v8.4s}, [x2], x18 + ld1 {v10.4s}, [x2], x18 + // step for output 1-2 + fmla v16.4s, v8.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + // load input for output 3-4 + ld1 {v2.4s, v3.4s}, [x12], #32 + // another step for output 1-2 + fmla v16.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + // load input for output 5-8 + // input cache should be refreshed after loading + // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 3-8 + fmla v20.4s, v8.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + + subs x13, x5, #1 + beq LoopIcEndHalf + + LoopIcHalf: + fmla v24.4s, v8.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + // load weight + ld1 {v12.4s}, [x2], x18 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + // load weight + ld1 {v14.4s}, [x2], x18 + fmla v24.4s, v10.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + // load weight + ld1 {v8.4s}, [x2], x18 + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + // load weight + ld1 {v10.4s}, [x2], x18 + fmla v20.4s, v14.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + // load input for output 1-4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + fmla v24.4s, v14.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + // load input for output 5-8 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 + // step for output 1-8 + fmla v16.4s, v8.4s, v0.s[0] + fmla v18.4s, v8.4s, v1.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v1.s[1] + fmla v20.4s, v8.4s, v2.s[0] + fmla v22.4s, v8.4s, v3.s[0] + + subs x13, x13, #1 + bne LoopIcHalf + + LoopIcEndHalf: + fmla v24.4s, v8.4s, v4.s[0] + fmla v26.4s, v8.4s, v5.s[0] + fmla v28.4s, v8.4s, v6.s[0] + fmla v30.4s, v8.4s, v7.s[0] + // load weight + ld1 {v12.4s}, [x2], x18 + // step for output 3-8 + fmla v20.4s, v10.4s, v2.s[1] + fmla v22.4s, v10.4s, v3.s[1] + // load weight + ld1 {v14.4s}, [x2], x18 + fmla v24.4s, v10.4s, v4.s[1] + fmla v26.4s, v10.4s, v5.s[1] + fmla v28.4s, v10.4s, v6.s[1] + fmla v30.4s, v10.4s, v7.s[1] + // another step for output 1-8 + fmla v16.4s, v12.4s, v0.s[2] + fmla v18.4s, v12.4s, v1.s[2] + fmla v20.4s, v12.4s, v2.s[2] + fmla v22.4s, v12.4s, v3.s[2] + fmla v24.4s, v12.4s, v4.s[2] + fmla v26.4s, v12.4s, v5.s[2] + fmla v28.4s, v12.4s, v6.s[2] + fmla v30.4s, v12.4s, v7.s[2] + // another step for output 1-8 + fmla v16.4s, v14.4s, v0.s[3] + fmla v18.4s, v14.4s, v1.s[3] + fmla v20.4s, v14.4s, v2.s[3] + fmla v22.4s, v14.4s, v3.s[3] + fmla v24.4s, v14.4s, v4.s[3] + fmla v26.4s, v14.4s, v5.s[3] + fmla v28.4s, v14.4s, v6.s[3] + fmla v30.4s, v14.4s, v7.s[3] + + cbnz x11, Relu6Half + cbnz x10, ReluHalf + b WriteStartHalf + Relu6Half: + movi v1.4s, #6 + scvtf v1.4s, v1.4s + fmin v16.4s, v16.4s ,v1.4s + fmin v18.4s, v18.4s ,v1.4s + fmin v20.4s, v20.4s ,v1.4s + fmin v22.4s, v22.4s ,v1.4s + fmin v24.4s, v24.4s ,v1.4s + fmin v26.4s, v26.4s ,v1.4s + fmin v28.4s, v28.4s ,v1.4s + fmin v30.4s, v30.4s ,v1.4s + ReluHalf: + dup v0.4s, wzr + fmax v16.4s, v16.4s ,v0.4s + fmax v18.4s, v18.4s ,v0.4s + fmax v20.4s, v20.4s ,v0.4s + fmax v22.4s, v22.4s ,v0.4s + fmax v24.4s, v24.4s ,v0.4s + fmax v26.4s, v26.4s ,v0.4s + fmax v28.4s, v28.4s ,v0.4s + fmax v30.4s, v30.4s ,v0.4s + + WriteStartHalf: + cbnz x9, Write4 + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + str s16, [x15] + add x15, x15, x7 + str s18, [x15] + add x15, x15, x7 + str s20, [x15] + add x15, x15, x7 + str s22, [x15] + add x15, x15, x7 + str s24, [x15] + add x15, x15, x7 + str s26, [x15] + add x15, x15, x7 + str s28, [x15] + add x15, x15, x7 + str s30, [x15] + add x0, x0, #4 + b WriteEnd + Write2: + dup s17, v16.s[1] + stp s16, s17, [x15] + add x15, x15, x7 + dup s19, v18.s[1] + stp s18, s19, [x15] + add x15, x15, x7 + dup s21, v20.s[1] + stp s20, s21, [x15] + add x15, x15, x7 + dup s23, v22.s[1] + stp s22, s23, [x15] + add x15, x15, x7 + dup s25, v24.s[1] + stp s24, s25, [x15] + add x15, x15, x7 + dup s27, v26.s[1] + stp s26, s27, [x15] + add x15, x15, x7 + dup s29, v28.s[1] + stp s28, s29, [x15] + add x15, x15, x7 + dup s31, v30.s[1] + stp s30, s31, [x15] + add x0, x0, #8 + b WriteEnd + Write3: + add x17, x15, #8 + dup s17, v16.s[1] + stp s16, s17, [x15] + add x15, x15, x7 + st1 {v16.s}[2], [x17], x7 + dup s19, v18.s[1] + stp s18, s19, [x15] + add x15, x15, x7 + st1 {v18.s}[2], [x17], x7 + dup s21, v20.s[1] + stp s20, s21, [x15] + add x15, x15, x7 + st1 {v20.s}[2], [x17], x7 + dup s23, v22.s[1] + stp s22, s23, [x15] + add x15, x15, x7 + st1 {v22.s}[2], [x17], x7 + dup s25, v24.s[1] + stp s24, s25, [x15] + add x15, x15, x7 + st1 {v24.s}[2], [x17], x7 + dup s27, v26.s[1] + stp s26, s27, [x15] + add x15, x15, x7 + st1 {v26.s}[2], [x17], x7 + dup s29, v28.s[1] + stp s28, s29, [x15] + add x15, x15, x7 + st1 {v28.s}[2], [x17], x7 + dup s31, v30.s[1] + stp s30, s31, [x15] + st1 {v30.s}[2], [x17] + add x0, x0, #12 + b WriteEndHalf + Write4: + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + // there are almost no benefits observed though + st1 {v16.4s}, [x15], x7 + st1 {v18.4s}, [x15], x7 + st1 {v20.4s}, [x15], x7 + st1 {v22.4s}, [x15], x7 + st1 {v24.4s}, [x15], x7 + st1 {v26.4s}, [x15], x7 + st1 {v28.4s}, [x15], x7 + st1 {v30.4s}, [x15] + add x0, x0, #16 + + WriteEndHalf: + + subs x14, x14, #1 + bne LoopKsizeHalf + +LoopOcEnd: + + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt16to32_8x4.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt16to32_8x4.S new file mode 100644 index 00000000000..bfad61a3627 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,221 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt16to32_8x4 +#ifndef __APPLE__ +.type IndirectGemmInt16to32_8x4, %function +#endif + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); +// x0: output, x1: input, x2: weight, x3: ksize, x4: ic8, x5: oc4, x6: offset +IndirectGemmInt16to32_8x4: + + .macro INIT_ZERO + dup v28.4s, wzr + mov v29.16b, v28.16b + mov v30.16b, v28.16b + mov v31.16b, v28.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + LoopOc: + mov x7, x3 + mov x8, x1 + + LoopKsize: + mov x9, x0 + INIT_ZERO + + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + // load weight + ld1 {v16.8h}, [x2], #16 + smull v24.4s, v16.4h, v0.h[0] + smull v25.4s, v16.4h, v1.h[0] + // load weight + ld1 {v17.8h}, [x2], #16 + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smull v26.4s, v16.4h, v2.h[0] + smull v27.4s, v16.4h, v3.h[0] + + subs x10, x4, #1 + beq LoopIcEnd + + LoopIc: + + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + // load weight + ld1 {v16.8h, v17.8h}, [x2], #32 + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v16.4h, v0.h[0] + smlal v25.4s, v16.4h, v1.h[0] + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + smlal v26.4s, v16.4h, v2.h[0] + smlal v27.4s, v16.4h, v3.h[0] + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + st1 {v24.4s}, [x9], x6 + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + st1 {v25.4s}, [x9], x6 + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + st1 {v26.4s}, [x9], x6 + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + st1 {v27.4s}, [x9], x6 + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + st1 {v28.4s}, [x9], x6 + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + st1 {v29.4s}, [x9], x6 + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + + st1 {v30.4s}, [x9], x6 + st1 {v31.4s}, [x9] + + subs x7, x7, #1 + add x0, x0, #16 + bne LoopKsize + + subs x5, x5, #1 + bne LoopOc + + ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt8_4x4.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt8_4x4.S new file mode 100644 index 00000000000..f70495e0e2a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/IndirectGemmInt8_4x4.S @@ -0,0 +1,326 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_4x4 +#ifndef __APPLE__ +.type IndirectGemmInt8_4x4, %function +#endif + +// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +IndirectGemmInt8_4x4: + + .macro INIT_BIAS + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x15, [sp] + ldr w8, [sp, #8] + ldr w9, [sp, #16] + ldr w16, [sp, #24] + ldr w17, [sp, #32] + ldr w18, [sp, #40] + ldr w19, [sp, #48] + + mul x5, x4, x5 + mov x4, #1 + + LoopOc: + + mov x10, x4 + mov x12, x1 + + LoopKsize: + INIT_BIAS + mov x11, x0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + // load weight + ld1 {v4.16b, v5.16b}, [x2], #32 + // step for output 1-4 + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v0.8b, v5.8b + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v0.16b, v5.16b + // load input for output 9-16 + ld1 {v6.16b, v7.16b}, [x2], #32 + // another step for output 5-8 + smull v12.8h, v1.8b, v4.8b + smull v13.8h, v1.8b, v5.8b + smlal2 v12.8h, v1.16b, v4.16b + smlal2 v13.8h, v1.16b, v5.16b + ld1 {v2.16b, v3.16b}, [x12], #32 + smull v10.8h, v0.8b, v6.8b + smull v11.8h, v0.8b, v7.8b + smlal2 v10.8h, v0.16b, v6.16b + smlal2 v11.8h, v0.16b, v7.16b + saddlp v16.4s, v8.8h + smull v14.8h, v1.8b, v6.8b + smull v15.8h, v1.8b, v7.8b + smlal2 v14.8h, v1.16b, v6.16b + smlal2 v15.8h, v1.16b, v7.16b + saddlp v17.4s, v9.8h + + subs x13, x5, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + sadalp v18.4s, v10.8h + smull v8.8h, v2.8b, v4.8b + smull v9.8h, v2.8b, v5.8b + sadalp v19.4s, v11.8h + smlal2 v8.8h, v2.16b, v4.16b + smlal2 v9.8h, v2.16b, v5.16b + sadalp v20.4s, v12.8h + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v2.8b, v7.8b + sadalp v21.4s, v13.8h + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v2.16b, v7.16b + sadalp v22.4s, v14.8h + smull v12.8h, v3.8b, v4.8b + smull v13.8h, v3.8b, v5.8b + sadalp v23.4s, v15.8h + smlal2 v12.8h, v3.16b, v4.16b + smlal2 v13.8h, v3.16b, v5.16b + sadalp v24.4s, v8.8h + ld1 {v4.16b, v5.16b}, [x2], #32 + smull v14.8h, v3.8b, v6.8b + smull v15.8h, v3.8b, v7.8b + sadalp v25.4s, v9.8h + smlal2 v14.8h, v3.16b, v6.16b + smlal2 v15.8h, v3.16b, v7.16b + sadalp v26.4s, v10.8h + ld1 {v6.16b, v7.16b}, [x2], #32 + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v0.8b, v5.8b + sadalp v27.4s, v11.8h + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v0.16b, v5.16b + sadalp v28.4s, v12.8h + ld1 {v2.16b, v3.16b}, [x12], #32 + smull v12.8h, v1.8b, v4.8b + smull v13.8h, v1.8b, v5.8b + sadalp v29.4s, v13.8h + smlal2 v12.8h, v1.16b, v4.16b + smlal2 v13.8h, v1.16b, v5.16b + sadalp v30.4s, v14.8h + smull v10.8h, v0.8b, v6.8b + smull v11.8h, v0.8b, v7.8b + sadalp v31.4s, v15.8h + smlal2 v10.8h, v0.16b, v6.16b + smlal2 v11.8h, v0.16b, v7.16b + sadalp v16.4s, v8.8h + smull v14.8h, v1.8b, v6.8b + smull v15.8h, v1.8b, v7.8b + sadalp v17.4s, v9.8h + smlal2 v14.8h, v1.16b, v6.16b + smlal2 v15.8h, v1.16b, v7.16b + + subs x13, x13, #1 + bne LoopIc + + LoopIcEnd: + sadalp v18.4s, v10.8h + smull v8.8h, v2.8b, v4.8b + smull v9.8h, v2.8b, v5.8b + sadalp v19.4s, v11.8h + smlal2 v8.8h, v2.16b, v4.16b + smlal2 v9.8h, v2.16b, v5.16b + sadalp v20.4s, v12.8h + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v2.8b, v7.8b + sadalp v21.4s, v13.8h + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v2.16b, v7.16b + sadalp v22.4s, v14.8h + smull v12.8h, v3.8b, v4.8b + smull v13.8h, v3.8b, v5.8b + sadalp v23.4s, v15.8h + smlal2 v12.8h, v3.16b, v4.16b + smlal2 v13.8h, v3.16b, v5.16b + sadalp v24.4s, v8.8h + smull v14.8h, v3.8b, v6.8b + smull v15.8h, v3.8b, v7.8b + sadalp v25.4s, v9.8h + smlal2 v14.8h, v3.16b, v6.16b + smlal2 v15.8h, v3.16b, v7.16b + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s ,v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + + // load sum + mov x20, x15 + ld1r {v8.4s}, [x20], #4 + ld1r {v9.4s}, [x20], #4 + ld1r {v10.4s}, [x20], #4 + ld1r {v11.4s}, [x20] + // pairwise add + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + cbz x3, NoReadBias + ld1 {v12.4s}, [x3] + NoReadBias: + addp v16.4s, v16.4s, v18.4s + addp v20.4s, v20.4s, v22.4s + addp v24.4s, v24.4s, v26.4s + addp v28.4s, v28.4s, v30.4s + sub v16.4s, v16.4s, v8.4s + sub v20.4s, v20.4s, v9.4s + sub v24.4s, v24.4s, v10.4s + sub v28.4s, v28.4s, v11.4s + add v16.4s, v16.4s, v12.4s + add v20.4s, v20.4s, v12.4s + add v24.4s, v24.4s, v12.4s + add v28.4s, v28.4s, v12.4s + + dup v2.4s, w18 + sqshl v16.4s, v16.4s ,v2.4s + sqshl v20.4s, v20.4s ,v2.4s + sqshl v24.4s, v24.4s ,v2.4s + sqshl v28.4s, v28.4s ,v2.4s + + dup v3.4s, w17 + sqrdmulh v16.4s, v16.4s ,v3.4s + sqrdmulh v20.4s, v20.4s ,v3.4s + sqrdmulh v24.4s, v24.4s ,v3.4s + sqrdmulh v28.4s, v28.4s ,v3.4s + + dup v4.4s, w19 + sqrshl v16.4s, v16.4s ,v4.4s + sqrshl v20.4s, v20.4s ,v4.4s + sqrshl v24.4s, v24.4s ,v4.4s + sqrshl v28.4s, v28.4s ,v4.4s + + dup v5.4s, w16 + add v16.4s, v16.4s ,v5.4s + add v20.4s, v20.4s ,v5.4s + add v24.4s, v24.4s ,v5.4s + add v28.4s, v28.4s ,v5.4s + + dup v0.4s, w8 + smax v16.4s, v16.4s ,v0.4s + smax v20.4s, v20.4s ,v0.4s + smax v24.4s, v24.4s ,v0.4s + smax v28.4s, v28.4s ,v0.4s + + dup v1.4s, w9 + smin v16.4s, v16.4s ,v1.4s + smin v20.4s, v20.4s ,v1.4s + smin v24.4s, v24.4s ,v1.4s + smin v28.4s, v28.4s ,v1.4s + + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v20.4s + sqxtn v15.8b, v13.8h + sqxtn v14.4h, v24.4s + sqxtn2 v14.8h, v28.4s + sqxtn2 v15.16b, v14.8h + + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + st1 {v15.b}[0], [x11], x7 + st1 {v15.b}[4], [x11], x7 + st1 {v15.b}[8], [x11], x7 + st1 {v15.b}[12], [x11] + add x0, x0, #1 + b WriteEnd + Write2: + st1 {v15.h}[0], [x11], x7 + st1 {v15.h}[2], [x11], x7 + st1 {v15.h}[4], [x11], x7 + st1 {v15.h}[6], [x11] + add x0, x0, #2 + b WriteEnd + Write3: + add x14, x11, #2 + st1 {v15.h}[0], [x11], x7 + st1 {v15.b}[2], [x14], x7 + st1 {v15.h}[2], [x11], x7 + st1 {v15.b}[6], [x14], x7 + st1 {v15.h}[4], [x11], x7 + st1 {v15.b}[10], [x14], x7 + st1 {v15.h}[6], [x11] + st1 {v15.b}[14], [x14] + add x0, x0, #3 + b WriteEnd + Write4: + st1 {v15.s}[0], [x11], x7 + st1 {v15.s}[1], [x11], x7 + st1 {v15.s}[2], [x11], x7 + st1 {v15.s}[3], [x11] + add x0, x0, #4 + + WriteEnd: + + subs x10, x10, #1 + bne LoopKsize + + subs x6, x6, #4 + cbz x3, NoStepFowrard + add x3, x3, #16 + NoStepFowrard: + bgt LoopOc + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add.S new file mode 100644 index 00000000000..181de0de726 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add.S @@ -0,0 +1,82 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAdd +#ifndef __APPLE__ + .type BiasAdd, %function +#endif + + + +//void BiasAdd(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAdd: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v5.4s, v0.4s, v1.4s +fadd v6.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +st1 {v5.4s, v6.4s}, [x1], #32 +fadd v7.4s, v0.4s, v3.4s +fadd v8.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v7.4s, v8.4s}, [x1], #32 +fadd v5.4s, v0.4s, v1.4s +fadd v6.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +st1 {v5.4s, v6.4s}, [x1], #32 +fadd v7.4s, v0.4s, v3.4s +fadd v8.4s, v0.4s, v4.4s + +st1 {v7.4s, v8.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v2.4s, v1.4s, v0.4s +subs x6, x6, #1 +st1 {v2.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu.S new file mode 100644 index 00000000000..f9e4eccc69f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu.S @@ -0,0 +1,94 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAddRelu +#ifndef __APPLE__ + .type BiasAddRelu, %function +#endif + + +//void BiasAddRelu(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAddRelu: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +dup v16.4s, wzr + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +st1 {v23.4s, v24.4s}, [x1], #32 +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s + +st1 {v27.4s, v28.4s}, [x1], #32 +ld1 {v3.4s, v4.4s}, [x5], #32 +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +st1 {v23.4s, v24.4s}, [x1], #32 +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +st1 {v27.4s, v28.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v1.4s, v1.4s, v0.4s +fmax v1.4s, v1.4s, v16.4s + +subs x6, x6, #1 +st1 {v1.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu6.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu6.S new file mode 100644 index 00000000000..77c563a8126 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/bias_add_relu6.S @@ -0,0 +1,113 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global BiasAddRelu6 +#ifndef __APPLE__ + .type BiasAddRelu6, %function +#endif + + + +//void BiasAddRelu6(const float* bias, float* data, size_t oc4, size_t plan_size) + +//Auto: x0:bias, x1: data, x2:oc4,x3: plan_size, + +BiasAddRelu6: +cmp x2, #0 +beq BiasAddEnd + +cmp x3, #0 +beq BiasAddEnd + +dup v16.4s, wzr +movi v17.4s, #6 +scvtf v17.4s, v17.4s + +LoopOc4: +ld1 {v0.4s}, [x0], #16 +mov x6, x3 +mov x5, x1 + +Loop16LineIn: +cmp x6, #4 +blt L4 +sub x6, x6, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s + + + +cmp x6, #4 +blt Loop16LineOut + +Loop16: +fmin v23.4s, v23.4s, v17.4s +fmin v24.4s, v24.4s, v17.4s +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v23.4s, v24.4s}, [x1], #32 +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +fadd v21.4s, v0.4s, v1.4s +fadd v22.4s, v0.4s, v2.4s + +fmin v27.4s, v27.4s, v17.4s +fmin v28.4s, v28.4s, v17.4s +fmax v23.4s, v21.4s, v16.4s +fmax v24.4s, v22.4s, v16.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +st1 {v27.4s, v28.4s}, [x1], #32 + + +sub x6, x6, #4 +cmp x6, #4 +bge Loop16 + +Loop16LineOut: +fmin v23.4s, v23.4s, v17.4s +fmin v24.4s, v24.4s, v17.4s +fadd v25.4s, v0.4s, v3.4s +fadd v26.4s, v0.4s, v4.4s + +fmax v27.4s, v25.4s, v16.4s +fmax v28.4s, v26.4s, v16.4s +st1 {v23.4s, v24.4s}, [x1], #32 + +fmin v27.4s, v27.4s, v17.4s +fmin v28.4s, v28.4s, v17.4s + +st1 {v27.4s, v28.4s}, [x1], #32 + +L4: +cmp x6, #0 +beq Loop16LineEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fadd v1.4s, v1.4s, v0.4s +fmax v1.4s, v1.4s, v16.4s +fmin v1.4s, v1.4s, v17.4s + +subs x6, x6, #1 +st1 {v1.4s}, [x1], #16 +bne Loop4 + +Loop16LineEnd: +subs x2, x2, #1 +bne LoopOc4 + +BiasAddEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s new file mode 100644 index 00000000000..17dddeb355d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s @@ -0,0 +1,315 @@ +#ifdef __aarch64__ + .text + .align 5 + .global MatMulFloatNeon64 +#ifndef __APPLE__ + .type MatMulFloatNeon64, %function +#endif + +// A: LM [row_8 * depth] col_8_major +// B: RM [depth * col_8] row_8_major +// C: A*B [row_8 * col_8] col_8x8_major +// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8] +/////////////////////////////////////////////////////////////////////////////// +//CommLoopMul RM 1x8 block +// /-----------------------------------------\ +// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| +// \-----------------------------------------/ +// LM 8x1 block +// /---------------------\ /-----------------------------------------\ +// | v0.s[0] | |v16.s[0] ... v30.s[0]| +// | ... | | ... ... | +// | v0.s[3] | |v16.s[3] ... v30.s[3]| +// | v1.s[0] | |v17.s[0] ... v31.s[0]| +// | ... | | ... ... | +// | v1.s[3] | |v17.s[3] ... v31.s[3]| +// \---------------------/ \-----------------------------------------/ +// accumulators 8x8 block +// +/////////////////////////////////////////////////////////////////////////////// +//OptLoopMul4 RHS 1x8 block +// /--------------------------------------------\ +// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] | +// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]| +// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]| +// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]| +// \--------------------------------------------/ +// LM 8x4 block +// /---------------------------------\ /--------------------------------------------\ +// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0] ... v30.s[0]| +// | ... ... ... ... | | ... ... | +// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v16.s[3] ... v30.s[3]| +// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v17.s[0] ... v31.s[0]| +// | ... ... ... ... | | ... ... | +// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v17.s[3] ... v31.s[3]| +// \---------------------------------/ \--------------------------------------------/ +// accumulators 8x8 block +///////////////////////////////////////////////////////////////////////////////// +// +// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// v0.s[0]: maxf +// v1.s[0]: minf +// w4: depth +// w5: row +// w6: col + +MatMulFloatNeon64: + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + + mov w7, v0.s[0] + mov w8, v1.s[0] + mov w9, 0 // row counter + mov w10, 0 // col counter + mov w18, #32 + mul w15, w4, w18 // the stride of a or b + mul w16, w6, w18 // the stride of c + +L1: + cmp w9, w5 + beq End1 + + mov w10, 0 // reset col counter + mov x12, x1 // reload b ptr + mov x17, x2 // reload current c ptr + mov x14, x3 // reload bias ptr +L2: + cmp w10, w6 + beq End2 + + mov x11, x0 // reload a ptr + mov w13, w4 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + +OptLoopMul4: + cmp w13, #4 + blt CommLoopMul + + ld1 {v0.4s}, [x11], #16 + ld1 {v8.4s}, [x12], #16 + fmla v16.4s, v0.4s, v8.s[0] + fmla v18.4s, v0.4s, v8.s[1] + ld1 {v1.4s}, [x11], #16 + fmla v20.4s, v0.4s, v8.s[2] + fmla v22.4s, v0.4s, v8.s[3] + ld1 {v9.4s}, [x12], #16 + fmla v25.4s, v1.4s, v9.s[0] + fmla v27.4s, v1.4s, v9.s[1] + fmla v29.4s, v1.4s, v9.s[2] + fmla v31.4s, v1.4s, v9.s[3] + ld1 {v2.4s}, [x11], #16 + ld1 {v3.4s}, [x11], #16 + fmla v24.4s, v0.4s, v9.s[0] + fmla v26.4s, v0.4s, v9.s[1] + fmla v28.4s, v0.4s, v9.s[2] + fmla v30.4s, v0.4s, v9.s[3] + fmla v17.4s, v1.4s, v8.s[0] + fmla v19.4s, v1.4s, v8.s[1] + fmla v21.4s, v1.4s, v8.s[2] + fmla v23.4s, v1.4s, v8.s[3] + ld1 {v10.4s}, [x12], #16 + ld1 {v11.4s}, [x12], #16 + fmla v16.4s, v2.4s, v10.s[0] + fmla v18.4s, v2.4s, v10.s[1] + fmla v20.4s, v2.4s, v10.s[2] + fmla v22.4s, v2.4s, v10.s[3] + fmla v25.4s, v3.4s, v11.s[0] + fmla v27.4s, v3.4s, v11.s[1] + fmla v29.4s, v3.4s, v11.s[2] + fmla v31.4s, v3.4s, v11.s[3] + ld1 {v4.4s}, [x11], #16 + ld1 {v5.4s}, [x11], #16 + fmla v24.4s, v2.4s, v11.s[0] + fmla v26.4s, v2.4s, v11.s[1] + fmla v28.4s, v2.4s, v11.s[2] + fmla v30.4s, v2.4s, v11.s[3] + fmla v17.4s, v3.4s, v10.s[0] + fmla v19.4s, v3.4s, v10.s[1] + fmla v21.4s, v3.4s, v10.s[2] + fmla v23.4s, v3.4s, v10.s[3] + ld1 {v12.4s}, [x12], #16 + ld1 {v13.4s}, [x12], #16 + fmla v16.4s, v4.4s, v12.s[0] + fmla v18.4s, v4.4s, v12.s[1] + fmla v20.4s, v4.4s, v12.s[2] + fmla v22.4s, v4.4s, v12.s[3] + fmla v25.4s, v5.4s, v13.s[0] + fmla v27.4s, v5.4s, v13.s[1] + fmla v29.4s, v5.4s, v13.s[2] + fmla v31.4s, v5.4s, v13.s[3] + ld1 {v6.4s}, [x11], #16 + ld1 {v7.4s}, [x11], #16 + fmla v24.4s, v4.4s, v13.s[0] + fmla v26.4s, v4.4s, v13.s[1] + fmla v28.4s, v4.4s, v13.s[2] + fmla v30.4s, v4.4s, v13.s[3] + fmla v17.4s, v5.4s, v12.s[0] + fmla v19.4s, v5.4s, v12.s[1] + fmla v21.4s, v5.4s, v12.s[2] + fmla v23.4s, v5.4s, v12.s[3] + ld1 {v14.4s}, [x12], #16 + ld1 {v15.4s}, [x12], #16 + fmla v16.4s, v6.4s, v14.s[0] + fmla v18.4s, v6.4s, v14.s[1] + fmla v20.4s, v6.4s, v14.s[2] + fmla v22.4s, v6.4s, v14.s[3] + fmla v25.4s, v7.4s, v15.s[0] + fmla v27.4s, v7.4s, v15.s[1] + fmla v29.4s, v7.4s, v15.s[2] + fmla v31.4s, v7.4s, v15.s[3] + fmla v24.4s, v6.4s, v15.s[0] + fmla v26.4s, v6.4s, v15.s[1] + fmla v28.4s, v6.4s, v15.s[2] + fmla v30.4s, v6.4s, v15.s[3] + fmla v17.4s, v7.4s, v14.s[0] + fmla v19.4s, v7.4s, v14.s[1] + fmla v21.4s, v7.4s, v14.s[2] + fmla v23.4s, v7.4s, v14.s[3] + subs w13, w13, #4 + b OptLoopMul4 + +CommLoopMul: + cmp w13, #1 + blt Bias + ld1 {v0.4s}, [x11], #16 + ld1 {v2.4s}, [x12], #16 + fmla v16.4s, v0.4s, v2.s[0] + fmla v18.4s, v0.4s, v2.s[1] + ld1 {v1.4s}, [x11], #16 + fmla v20.4s, v0.4s, v2.s[2] + fmla v22.4s, v0.4s, v2.s[3] + ld1 {v3.4s}, [x12], #16 + fmla v25.4s, v1.4s, v3.s[0] + fmla v27.4s, v1.4s, v3.s[1] + fmla v29.4s, v1.4s, v3.s[2] + fmla v31.4s, v1.4s, v3.s[3] + fmla v24.4s, v0.4s, v3.s[0] + fmla v26.4s, v0.4s, v3.s[1] + fmla v28.4s, v0.4s, v3.s[2] + fmla v30.4s, v0.4s, v3.s[3] + fmla v17.4s, v1.4s, v2.s[0] + fmla v19.4s, v1.4s, v2.s[1] + fmla v21.4s, v1.4s, v2.s[2] + fmla v23.4s, v1.4s, v2.s[3] + subs w13, w13, #1 + b CommLoopMul + +Bias: + ld1 {v0.4s}, [x14], #16 + ld1 {v1.4s}, [x14], #16 + dup v2.4s, v0.s[0] + fadd v16.4s, v16.4s, v2.4s + fadd v17.4s, v17.4s, v2.4s + dup v3.4s, v0.s[1] + fadd v18.4s, v18.4s, v3.4s + fadd v19.4s, v19.4s, v3.4s + dup v4.4s, v0.s[2] + fadd v20.4s, v20.4s, v4.4s + fadd v21.4s, v21.4s, v4.4s + dup v5.4s, v0.s[3] + fadd v22.4s, v22.4s, v5.4s + fadd v23.4s, v23.4s, v5.4s + dup v2.4s, v1.s[0] + fadd v24.4s, v24.4s, v2.4s + fadd v25.4s, v25.4s, v2.4s + dup v3.4s, v1.s[1] + fadd v26.4s, v26.4s, v3.4s + fadd v27.4s, v27.4s, v3.4s + dup v4.4s, v1.s[2] + fadd v28.4s, v28.4s, v4.4s + fadd v29.4s, v29.4s, v4.4s + dup v5.4s, v1.s[3] + fadd v30.4s, v30.4s, v5.4s + fadd v31.4s, v31.4s, v5.4s + +Relu: + dup v15.4s, w7 + dup v14.4s, w8 + fmax v16.4s, v16.4s, v14.4s + fmax v17.4s, v17.4s, v14.4s + fmax v18.4s, v18.4s, v14.4s + fmax v19.4s, v19.4s, v14.4s + fmax v20.4s, v20.4s, v14.4s + fmax v21.4s, v21.4s, v14.4s + fmax v22.4s, v22.4s, v14.4s + fmax v23.4s, v23.4s, v14.4s + fmax v24.4s, v24.4s, v14.4s + fmax v25.4s, v25.4s, v14.4s + fmax v26.4s, v26.4s, v14.4s + fmax v27.4s, v27.4s, v14.4s + fmax v28.4s, v28.4s, v14.4s + fmax v29.4s, v29.4s, v14.4s + fmax v30.4s, v30.4s, v14.4s + fmax v31.4s, v31.4s, v14.4s + + fmin v16.4s, v16.4s, v15.4s + fmin v17.4s, v17.4s, v15.4s + fmin v18.4s, v18.4s, v15.4s + fmin v19.4s, v19.4s, v15.4s + fmin v20.4s, v20.4s, v15.4s + fmin v20.4s, v20.4s, v15.4s + fmin v21.4s, v21.4s, v15.4s + fmin v22.4s, v22.4s, v15.4s + fmin v23.4s, v23.4s, v15.4s + fmin v24.4s, v24.4s, v15.4s + fmin v25.4s, v25.4s, v15.4s + fmin v26.4s, v26.4s, v15.4s + fmin v27.4s, v27.4s, v15.4s + fmin v28.4s, v28.4s, v15.4s + fmin v29.4s, v29.4s, v15.4s + fmin v30.4s, v30.4s, v15.4s + fmin v31.4s, v31.4s, v15.4s + +TransToOut: + st1 {v16.4s}, [x17], #16 + st1 {v17.4s}, [x17], #16 + st1 {v18.4s}, [x17], #16 + st1 {v19.4s}, [x17], #16 + st1 {v20.4s}, [x17], #16 + st1 {v21.4s}, [x17], #16 + st1 {v22.4s}, [x17], #16 + st1 {v23.4s}, [x17], #16 + st1 {v24.4s}, [x17], #16 + st1 {v25.4s}, [x17], #16 + st1 {v26.4s}, [x17], #16 + st1 {v27.4s}, [x17], #16 + st1 {v28.4s}, [x17], #16 + st1 {v29.4s}, [x17], #16 + st1 {v30.4s}, [x17], #16 + st1 {v31.4s}, [x17], #16 + + add w10, w10, #8 // col+=8 + b L2 + +End2: + add x0, x0, x15 // stride a ptr + add x2, x2, x16 // stride c ptr + add w9, w9, #8 // row+=8 + b L1 + +End1: + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_add.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_add.S new file mode 100644 index 00000000000..d3611903611 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_add.S @@ -0,0 +1,103 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global MatrixAdd +#ifndef __APPLE__ + .type MatrixAdd, %function +#endif + + + +//void MatrixAdd(const float* matDataA, const float* matDataB, float* matDataC, +// size_t aStride, size_t bStride, size_t cStride, size_t width, size_t height) + +//Auto: x0: matDataA, x1:matDataB, x2:matDatac, +//x3:aStride, x4:bStride, x5:cStride, x6:width, x7:height + +MatrixAdd: +mov x12, #4 //sizeof(float) +mul x3, x12, x3 +mul x4, x12, x4 +mul x5, x12, x5 + +loopH: +mov x8, x0 +mov x9, x1 +mov x10, x2 + +mov x11, x6 + +loop16LineIn: +cmp x11, #4 +blt L8 +sub x11, x11, #4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s + +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +cmp x11, #4 +blt loop16LineOut + +loop16: +st1 {v4.4s, v5.4s}, [x2], #32 +fadd v10.4s, v6.4s, v8.4s +fadd v11.4s, v7.4s, v9.4s +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +st1 {v10.4s, v11.4s}, [x2], #32 +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +sub x11, x11, #4 +cmp x11, #4 +bge loop16 + +loop16LineOut: +st1 {v4.4s, v5.4s}, [x2], #32 +fadd v10.4s, v6.4s, v8.4s +fadd v11.4s, v7.4s, v9.4s +st1 {v10.4s, v11.4s}, [x2], #32 + + +L8: +cmp x11, #2 +blt L4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 +fadd v4.4s, v0.4s, v2.4s +fadd v5.4s, v1.4s, v3.4s +sub x11, x11, #2 +st1 {v4.4s, v5.4s}, [x2], #32 + + +cmp x11, #0 +beq loop16EndLine + +L4: +ld1 {v0.4s}, [x0], #16 +ld1 {v1.4s}, [x1], #16 +fadd v0.4s, v0.4s, v1.4s +sub x11, x11, #1 +st1 {v0.4s}, [x2], #16 +//bne L4 + +loop16EndLine: +add x0, x8, x3 +add x1, x9, x4 +add x2, x10, x5 + +subs x7, x7, #1 +bne loopH + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_sub.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_sub.S new file mode 100644 index 00000000000..7ac5f56a391 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matrix_sub.S @@ -0,0 +1,105 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global MatrixSub +#ifndef __APPLE__ + .type MatrixSub, %function +#endif + + + +//void MatrixSub(const float* matDataA, const float* matDataB, float* matDataC, +// size_t aStride, size_t bStride, size_t cStride, size_t width, size_t height) + +//Auto: x0: matDataA, x1:matDataB, x2:matDatac, +//x3:aStride, x4:bStride, x5:cStride, x6:width, x7:height + +MatrixSub: +mov x12, #4 //sizeof(float) +mul x3, x12, x3 +mul x4, x12, x4 +mul x5, x12, x5 + +loopH: +mov x8, x0 +mov x9, x1 +mov x10, x2 + +mov x11, x6 + +loop16LineIn: +cmp x11, #4 +blt L8 +sub x11, x11, #4 +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s + +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +cmp x11, #4 +blt loop16LineOut + +loop16: +st1 {v4.4s, v5.4s}, [x2], #32 +fsub v10.4s, v6.4s, v8.4s +fsub v11.4s, v7.4s, v9.4s +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +st1 {v10.4s, v11.4s}, [x2], #32 +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s +ld1 {v6.4s, v7.4s}, [x0], #32 +ld1 {v8.4s, v9.4s}, [x1], #32 + +sub x11, x11, #4 +cmp x11, #4 +bge loop16 + +loop16LineOut: +st1 {v4.4s, v5.4s}, [x2], #32 +fsub v10.4s, v6.4s, v8.4s +fsub v11.4s, v7.4s, v9.4s +st1 {v10.4s, v11.4s}, [x2], #32 + +L8: +cmp x11, #2 +blt L4 + +ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v2.4s, v3.4s}, [x1], #32 + +fsub v4.4s, v0.4s, v2.4s +fsub v5.4s, v1.4s, v3.4s + +sub x11, x11, #2 +st1 {v4.4s, v5.4s}, [x2], #32 + + +cmp x11, #0 +beq loop16EndLine + +L4: +ld1 {v0.4s}, [x0], #16 +ld1 {v1.4s}, [x1], #16 +fsub v0.4s, v0.4s, v1.4s +sub x11, x11, #1 +st1 {v0.4s}, [x2], #16 + + +loop16EndLine: +add x0, x8, x3 +add x1, x9, x4 +add x2, x10, x5 + +subs x7, x7, #1 +bne loopH + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu.S new file mode 100644 index 00000000000..74c40a135b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu.S @@ -0,0 +1,73 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global Relu +#ifndef __APPLE__ + .type Relu, %function +#endif + + +//void Relu(float* data, size_t element4) + +//Auto: x0:data, x1: element4 + +Relu: +cmp x1, #0 +beq ReluEnd + +dup v16.4s, wzr + +mov x5, x0 + +Loop16LineIn: +cmp x1, #4 +blt L4 +sub x1, x1, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v5.4s, v16.4s, v1.4s +fmax v6.4s, v16.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +cmp x1, #4 +blt Loop16LineOut + +Loop16: +st1 {v5.4s, v6.4s}, [x0], #32 +fmax v7.4s, v16.4s, v3.4s +fmax v8.4s, v16.4s, v4.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +st1 {v7.4s, v8.4s}, [x0], #32 +fmax v5.4s, v16.4s, v1.4s +fmax v6.4s, v16.4s, v2.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +sub x1, x1, #4 +cmp x1, #4 +bge Loop16 + +Loop16LineOut: +st1 {v5.4s, v6.4s}, [x0], #32 +fmax v7.4s, v16.4s, v3.4s +fmax v8.4s, v16.4s, v4.4s + +st1 {v7.4s, v8.4s}, [x0], #32 + +L4: +cmp x1, #0 +beq ReluEnd +Loop4: +ld1 {v1.4s}, [x5], #16 +fmax v2.4s, v16.4s, v0.4s +subs x1, x1, #1 +st1 {v2.4s}, [x0], #16 +bne Loop4 + +ReluEnd: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu6.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu6.S new file mode 100644 index 00000000000..c1789845eec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/relu6.S @@ -0,0 +1,89 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global Relu6 +#ifndef __APPLE__ + .type Relu6, %function +#endif + + +//void Relu6(float* data, size_t element4) + +//Auto: x0:data, x1: element4 + +Relu6: +cmp x1, #0 +beq Relu6End + +dup v16.4s, wzr +movi v17.4s, #6 +scvtf v17.4s, v17.4s + +mov x5, x0 + +Loop16LineIn: +cmp x1, #4 +blt L4 +sub x1, x1, #4 + +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmax v21.4s, v1.4s, v16.4s +fmax v22.4s, v2.4s, v16.4s +ld1 {v3.4s, v4.4s}, [x5], #32 + +fmin v23.4s, v21.4s, v17.4s +fmin v24.4s, v22.4s, v17.4s + + +cmp x1, #4 +blt Loop16LineOut + +Loop16: +st1 {v23.4s, v24.4s}, [x0], #32 +fmax v25.4s, v3.4s, v16.4s +fmax v26.4s, v4.4s, v16.4s +ld1 {v1.4s, v2.4s}, [x5], #32 + +fmin v27.4s, v25.4s, v17.4s +fmin v28.4s, v26.4s, v17.4s +fmax v21.4s, v1.4s, v16.4s +fmax v22.4s, v2.4s, v16.4s + +st1 {v27.4s, v28.4s}, [x0], #32 +ld1 {v3.4s, v4.4s}, [x5], #32 +fmin v23.4s, v21.4s, v17.4s +fmin v24.4s, v22.4s, v17.4s + +sub x1, x1, #4 +cmp x1, #4 +bge Loop16 + +Loop16LineOut: +st1 {v23.4s, v24.4s}, [x0], #32 +fmax v25.4s, v3.4s, v16.4s +fmax v26.4s, v4.4s, v16.4s + +fmin v27.4s, v25.4s, v17.4s +fmin v28.4s, v26.4s, v17.4s +st1 {v27.4s, v28.4s}, [x0], #32 + +L4: +cmp x1, #0 +beq Relu6End +Loop4: +ld1 {v1.4s}, [x5], #16 +fmax v1.4s, v1.4s, v16.4s + +fmin v1.4s, v1.4s, v17.4s + +subs x1, x1, #1 +st1 {v1.4s}, [x0], #16 +bne Loop4 + +Relu6End: + +ret +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/opt/IndirectGemmInt8_24x4_dp.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/opt/IndirectGemmInt8_24x4_dp.S new file mode 100644 index 00000000000..278b4376b2e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/opt/IndirectGemmInt8_24x4_dp.S @@ -0,0 +1,636 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global IndirectGemmInt8_24x4_dp +#ifndef __APPLE__ +.type IndirectGemmInt8_24x4_dp, %function +#endif + +// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, +// size_t shift_before, size_t shift_after); +// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset +// we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later) +// mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A) +// the 44-48 bits indicates whether dotprod is supported +IndirectGemmInt8_24x4_dp: + + .macro INIT_BIAS + mov x20, x15 + ld1r {v8.4s}, [x20], #4 + ld1r {v9.4s}, [x20], #4 + ld1r {v10.4s}, [x20], #4 + ld1r {v11.4s}, [x20], #4 + ld1r {v12.4s}, [x20], #4 + ld1r {v13.4s}, [x20], #4 + ld1r {v14.4s}, [x20], #4 + ld1r {v15.4s}, [x20], #4 + ld1r {v16.4s}, [x20], #4 + ld1r {v17.4s}, [x20], #4 + ld1r {v18.4s}, [x20], #4 + ld1r {v19.4s}, [x20], #4 + ld1r {v20.4s}, [x20], #4 + ld1r {v21.4s}, [x20], #4 + ld1r {v22.4s}, [x20], #4 + ld1r {v23.4s}, [x20], #4 + ld1r {v24.4s}, [x20], #4 + ld1r {v25.4s}, [x20], #4 + ld1r {v26.4s}, [x20], #4 + ld1r {v27.4s}, [x20], #4 + ld1r {v28.4s}, [x20], #4 + ld1r {v29.4s}, [x20], #4 + ld1r {v30.4s}, [x20], #4 + ld1r {v31.4s}, [x20], #4 + dup v7.4s, wzr + cbz x3, InitBias + ld1 {v7.4s}, [x3] + InitBias: + sub v8.4s, v7.4s, v8.4s + sub v9.4s, v7.4s, v9.4s + sub v10.4s, v7.4s, v10.4s + sub v11.4s, v7.4s, v11.4s + sub v12.4s, v7.4s, v12.4s + sub v13.4s, v7.4s, v13.4s + sub v14.4s, v7.4s, v14.4s + sub v15.4s, v7.4s, v15.4s + sub v16.4s, v7.4s, v16.4s + sub v17.4s, v7.4s, v17.4s + sub v18.4s, v7.4s, v18.4s + sub v19.4s, v7.4s, v19.4s + sub v20.4s, v7.4s, v20.4s + sub v21.4s, v7.4s, v21.4s + sub v22.4s, v7.4s, v22.4s + sub v23.4s, v7.4s, v23.4s + sub v24.4s, v7.4s, v24.4s + sub v25.4s, v7.4s, v25.4s + sub v26.4s, v7.4s, v26.4s + sub v27.4s, v7.4s, v27.4s + sub v28.4s, v7.4s, v28.4s + sub v29.4s, v7.4s, v29.4s + sub v30.4s, v7.4s, v30.4s + sub v31.4s, v7.4s, v31.4s + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // r19 ~ r29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x15, [sp] + ldr w8, [sp, #8] + ldr w9, [sp, #16] + ldr w16, [sp, #24] + ldr w17, [sp, #32] + ldr w18, [sp, #40] + ldr w19, [sp, #48] + + mul x5, x4, x5 + mov x4, #1 + + LoopOc: + + mov x10, x4 + mov x12, x1 + + LoopKsize: + INIT_BIAS + mov x11, x0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + // load weight + ld1 {v6.16b}, [x2], #16 + // step for output 1-4 + .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] + .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] + .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] + .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] + // load input for output 9-16 + ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x12], #64 + // another step for output 5-8 + .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] + .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] + .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] + .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] + + subs x13, x5, #1 + beq LoopIcEndOne + // load weight + ld1 {v7.16b}, [x2], #16 + cmp x13, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + ld1 {v2.16b, v3.16b}, [x12], #32 + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] + .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] + .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] + .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] + // another step for output 5-8 + .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] + .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] + .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] + .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] + .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] + .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] + .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] + .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] + .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] + .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] + .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] + // load weight + ld1 {v6.16b}, [x2], #16 + .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] + .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] + .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] + .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] + .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] + .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] + .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] + .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v2.4s, v3.4s}, [x12], #32 + .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] + .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] + .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] + .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] + // another step for output 5-8 + .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] + .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] + .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] + .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + + subs x13, x13, #2 + beq LoopIcEndOne + // load weight + ld1 {v7.16b}, [x2], #16 + cmp x13, #1 + beq LoopIcEnd + b LoopIc + + LoopIcEnd: + mov x20, x15 + // load input for output 1-8 + ld1 {v0.16b, v1.16b}, [x12], #32 + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + ld1 {v2.16b, v3.16b}, [x12], #32 + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + // load input for output 9-16 + ld1 {v4.4s, v5.4s}, [x12], #32 + .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] + .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] + .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] + .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] + .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] + .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] + .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] + .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] + + .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] + .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] + .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] + .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] + .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] + .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] + .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] + .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] + + .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] + .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] + .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] + .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] + .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] + .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] + .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] + .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] + b Quantization + + LoopIcEndOne: + .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] + .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] + .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] + .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] + .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] + .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] + .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] + .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] + + .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] + .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] + .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] + .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] + .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] + .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] + .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] + .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] + + Quantization: + dup v2.4s, w18 + sqshl v8.4s, v8.4s ,v2.4s + sqshl v9.4s, v9.4s ,v2.4s + sqshl v10.4s, v10.4s ,v2.4s + sqshl v11.4s, v11.4s ,v2.4s + sqshl v12.4s, v12.4s ,v2.4s + sqshl v13.4s, v13.4s ,v2.4s + sqshl v14.4s, v14.4s ,v2.4s + sqshl v15.4s, v15.4s ,v2.4s + sqshl v16.4s, v16.4s ,v2.4s + sqshl v17.4s, v17.4s ,v2.4s + sqshl v18.4s, v18.4s ,v2.4s + sqshl v19.4s, v19.4s ,v2.4s + sqshl v20.4s, v20.4s ,v2.4s + sqshl v21.4s, v21.4s ,v2.4s + sqshl v22.4s, v22.4s ,v2.4s + sqshl v23.4s, v23.4s ,v2.4s + sqshl v24.4s, v24.4s ,v2.4s + sqshl v25.4s, v25.4s ,v2.4s + sqshl v26.4s, v26.4s ,v2.4s + sqshl v27.4s, v27.4s ,v2.4s + sqshl v28.4s, v28.4s ,v2.4s + sqshl v29.4s, v29.4s ,v2.4s + sqshl v30.4s, v30.4s ,v2.4s + sqshl v31.4s, v31.4s ,v2.4s + + dup v3.4s, w17 + sqrdmulh v8.4s, v8.4s ,v3.4s + sqrdmulh v9.4s, v9.4s ,v3.4s + sqrdmulh v10.4s, v10.4s ,v3.4s + sqrdmulh v11.4s, v11.4s ,v3.4s + sqrdmulh v12.4s, v12.4s ,v3.4s + sqrdmulh v13.4s, v13.4s ,v3.4s + sqrdmulh v14.4s, v14.4s ,v3.4s + sqrdmulh v15.4s, v15.4s ,v3.4s + sqrdmulh v16.4s, v16.4s ,v3.4s + sqrdmulh v17.4s, v17.4s ,v3.4s + sqrdmulh v18.4s, v18.4s ,v3.4s + sqrdmulh v19.4s, v19.4s ,v3.4s + sqrdmulh v20.4s, v20.4s ,v3.4s + sqrdmulh v21.4s, v21.4s ,v3.4s + sqrdmulh v22.4s, v22.4s ,v3.4s + sqrdmulh v23.4s, v23.4s ,v3.4s + sqrdmulh v24.4s, v24.4s ,v3.4s + sqrdmulh v25.4s, v25.4s ,v3.4s + sqrdmulh v26.4s, v26.4s ,v3.4s + sqrdmulh v27.4s, v27.4s ,v3.4s + sqrdmulh v28.4s, v28.4s ,v3.4s + sqrdmulh v29.4s, v29.4s ,v3.4s + sqrdmulh v30.4s, v30.4s ,v3.4s + sqrdmulh v31.4s, v31.4s ,v3.4s + + dup v4.4s, w19 + sqrshl v8.4s, v8.4s ,v4.4s + sqrshl v9.4s, v9.4s ,v4.4s + sqrshl v10.4s, v10.4s ,v4.4s + sqrshl v11.4s, v11.4s ,v4.4s + sqrshl v12.4s, v12.4s ,v4.4s + sqrshl v13.4s, v13.4s ,v4.4s + sqrshl v14.4s, v14.4s ,v4.4s + sqrshl v15.4s, v15.4s ,v4.4s + sqrshl v16.4s, v16.4s ,v4.4s + sqrshl v17.4s, v17.4s ,v4.4s + sqrshl v18.4s, v18.4s ,v4.4s + sqrshl v19.4s, v19.4s ,v4.4s + sqrshl v20.4s, v20.4s ,v4.4s + sqrshl v21.4s, v21.4s ,v4.4s + sqrshl v22.4s, v22.4s ,v4.4s + sqrshl v23.4s, v23.4s ,v4.4s + sqrshl v24.4s, v24.4s ,v4.4s + sqrshl v25.4s, v25.4s ,v4.4s + sqrshl v26.4s, v26.4s ,v4.4s + sqrshl v27.4s, v27.4s ,v4.4s + sqrshl v28.4s, v28.4s ,v4.4s + sqrshl v29.4s, v29.4s ,v4.4s + sqrshl v30.4s, v30.4s ,v4.4s + sqrshl v31.4s, v31.4s ,v4.4s + + dup v5.4s, w16 + add v8.4s, v8.4s ,v5.4s + add v9.4s, v9.4s ,v5.4s + add v10.4s, v10.4s ,v5.4s + add v11.4s, v11.4s ,v5.4s + add v12.4s, v12.4s ,v5.4s + add v13.4s, v13.4s ,v5.4s + add v14.4s, v14.4s ,v5.4s + add v15.4s, v15.4s ,v5.4s + add v16.4s, v16.4s ,v5.4s + add v17.4s, v17.4s ,v5.4s + add v18.4s, v18.4s ,v5.4s + add v19.4s, v19.4s ,v5.4s + add v20.4s, v20.4s ,v5.4s + add v21.4s, v21.4s ,v5.4s + add v22.4s, v22.4s ,v5.4s + add v23.4s, v23.4s ,v5.4s + add v24.4s, v24.4s ,v5.4s + add v25.4s, v25.4s ,v5.4s + add v26.4s, v26.4s ,v5.4s + add v27.4s, v27.4s ,v5.4s + add v28.4s, v28.4s ,v5.4s + add v29.4s, v29.4s ,v5.4s + add v30.4s, v30.4s ,v5.4s + add v31.4s, v31.4s ,v5.4s + + dup v0.4s, w8 + smax v8.4s, v8.4s ,v0.4s + smax v9.4s, v9.4s ,v0.4s + smax v10.4s, v10.4s ,v0.4s + smax v11.4s, v11.4s ,v0.4s + smax v12.4s, v12.4s ,v0.4s + smax v13.4s, v13.4s ,v0.4s + smax v14.4s, v14.4s ,v0.4s + smax v15.4s, v15.4s ,v0.4s + smax v16.4s, v16.4s ,v0.4s + smax v17.4s, v17.4s ,v0.4s + smax v18.4s, v18.4s ,v0.4s + smax v19.4s, v19.4s ,v0.4s + smax v20.4s, v20.4s ,v0.4s + smax v21.4s, v21.4s ,v0.4s + smax v22.4s, v22.4s ,v0.4s + smax v23.4s, v23.4s ,v0.4s + smax v24.4s, v24.4s ,v0.4s + smax v25.4s, v25.4s ,v0.4s + smax v26.4s, v26.4s ,v0.4s + smax v27.4s, v27.4s ,v0.4s + smax v28.4s, v28.4s ,v0.4s + smax v29.4s, v29.4s ,v0.4s + smax v30.4s, v30.4s ,v0.4s + smax v31.4s, v31.4s ,v0.4s + + dup v1.4s, w9 + smin v8.4s, v8.4s ,v1.4s + smin v9.4s, v9.4s ,v1.4s + smin v10.4s, v10.4s ,v1.4s + smin v11.4s, v11.4s ,v1.4s + smin v12.4s, v12.4s ,v1.4s + smin v13.4s, v13.4s ,v1.4s + smin v14.4s, v14.4s ,v1.4s + smin v15.4s, v15.4s ,v1.4s + smin v16.4s, v16.4s ,v1.4s + smin v17.4s, v17.4s ,v1.4s + smin v18.4s, v18.4s ,v1.4s + smin v19.4s, v19.4s ,v1.4s + smin v20.4s, v20.4s ,v1.4s + smin v21.4s, v21.4s ,v1.4s + smin v22.4s, v22.4s ,v1.4s + smin v23.4s, v23.4s ,v1.4s + smin v24.4s, v24.4s ,v1.4s + smin v25.4s, v25.4s ,v1.4s + smin v26.4s, v26.4s ,v1.4s + smin v27.4s, v27.4s ,v1.4s + smin v28.4s, v28.4s ,v1.4s + smin v29.4s, v29.4s ,v1.4s + smin v30.4s, v30.4s ,v1.4s + smin v31.4s, v31.4s ,v1.4s + + sqxtn v6.4h, v8.4s + sqxtn2 v6.8h, v9.4s + sqxtn v0.8b, v6.8h + sqxtn v7.4h, v10.4s + sqxtn2 v7.8h, v11.4s + sqxtn2 v0.16b, v7.8h + + sqxtn v6.4h, v12.4s + sqxtn2 v6.8h, v13.4s + sqxtn v1.8b, v6.8h + sqxtn v7.4h, v14.4s + sqxtn2 v7.8h, v15.4s + sqxtn2 v1.16b, v7.8h + + sqxtn v6.4h, v16.4s + sqxtn2 v6.8h, v17.4s + sqxtn v2.8b, v6.8h + sqxtn v7.4h, v18.4s + sqxtn2 v7.8h, v19.4s + sqxtn2 v2.16b, v7.8h + + sqxtn v6.4h, v20.4s + sqxtn2 v6.8h, v21.4s + sqxtn v3.8b, v6.8h + sqxtn v7.4h, v22.4s + sqxtn2 v7.8h, v23.4s + sqxtn2 v3.16b, v7.8h + + sqxtn v6.4h, v24.4s + sqxtn2 v6.8h, v25.4s + sqxtn v4.8b, v6.8h + sqxtn v7.4h, v26.4s + sqxtn2 v7.8h, v27.4s + sqxtn2 v4.16b, v7.8h + + sqxtn v6.4h, v28.4s + sqxtn2 v6.8h, v29.4s + sqxtn v5.8b, v6.8h + sqxtn v7.4h, v30.4s + sqxtn2 v7.8h, v31.4s + sqxtn2 v5.16b, v7.8h + // prefetching is not prefered while writing results in spite of cache missings + // you could try prfm pstl2strm + WriteStart: + cmp x6, #1 + beq Write1 + cmp x6, #2 + beq Write2 + cmp x6, #3 + beq Write3 + b Write4 + Write1: + st1 {v0.b}[0], [x11], x7 + st1 {v0.b}[4], [x11], x7 + st1 {v0.b}[8], [x11], x7 + st1 {v0.b}[12], [x11], x7 + st1 {v1.b}[0], [x11], x7 + st1 {v1.b}[4], [x11], x7 + st1 {v1.b}[8], [x11], x7 + st1 {v1.b}[12], [x11], x7 + st1 {v2.b}[0], [x11], x7 + st1 {v2.b}[4], [x11], x7 + st1 {v2.b}[8], [x11], x7 + st1 {v2.b}[12], [x11], x7 + st1 {v3.b}[0], [x11], x7 + st1 {v3.b}[4], [x11], x7 + st1 {v3.b}[8], [x11], x7 + st1 {v3.b}[12], [x11], x7 + st1 {v4.b}[0], [x11], x7 + st1 {v4.b}[4], [x11], x7 + st1 {v4.b}[8], [x11], x7 + st1 {v4.b}[12], [x11], x7 + st1 {v5.b}[0], [x11], x7 + st1 {v5.b}[4], [x11], x7 + st1 {v5.b}[8], [x11], x7 + st1 {v5.b}[12], [x11] + add x0, x0, #1 + b WriteEnd + Write2: + st1 {v0.h}[0], [x11], x7 + st1 {v0.h}[2], [x11], x7 + st1 {v0.h}[4], [x11], x7 + st1 {v0.h}[6], [x11], x7 + st1 {v1.h}[0], [x11], x7 + st1 {v1.h}[2], [x11], x7 + st1 {v1.h}[4], [x11], x7 + st1 {v1.h}[6], [x11], x7 + st1 {v2.h}[0], [x11], x7 + st1 {v2.h}[2], [x11], x7 + st1 {v2.h}[4], [x11], x7 + st1 {v2.h}[6], [x11], x7 + st1 {v3.h}[0], [x11], x7 + st1 {v3.h}[2], [x11], x7 + st1 {v3.h}[4], [x11], x7 + st1 {v3.h}[6], [x11], x7 + st1 {v4.h}[0], [x11], x7 + st1 {v4.h}[2], [x11], x7 + st1 {v4.h}[4], [x11], x7 + st1 {v4.h}[6], [x11], x7 + st1 {v5.h}[0], [x11], x7 + st1 {v5.h}[2], [x11], x7 + st1 {v5.h}[4], [x11], x7 + st1 {v5.h}[6], [x11] + add x0, x0, #2 + b WriteEnd + Write3: + add x14, x11, #2 + st1 {v0.h}[0], [x11], x7 + st1 {v0.b}[2], [x14], x7 + st1 {v0.h}[2], [x11], x7 + st1 {v0.b}[6], [x14], x7 + st1 {v0.h}[4], [x11], x7 + st1 {v0.b}[10], [x14], x7 + st1 {v0.h}[6], [x11], x7 + st1 {v0.b}[14], [x14], x7 + st1 {v1.h}[0], [x11], x7 + st1 {v1.b}[2], [x14], x7 + st1 {v1.h}[2], [x11], x7 + st1 {v1.b}[6], [x14], x7 + st1 {v1.h}[4], [x11], x7 + st1 {v1.b}[10], [x14], x7 + st1 {v1.h}[6], [x11], x7 + st1 {v1.b}[14], [x14], x7 + st1 {v2.h}[0], [x11], x7 + st1 {v2.b}[2], [x14], x7 + st1 {v2.h}[2], [x11], x7 + st1 {v2.b}[6], [x14], x7 + st1 {v2.h}[4], [x11], x7 + st1 {v2.b}[10], [x14], x7 + st1 {v2.h}[6], [x11], x7 + st1 {v2.b}[14], [x14], x7 + st1 {v3.h}[0], [x11], x7 + st1 {v3.b}[2], [x14], x7 + st1 {v3.h}[2], [x11], x7 + st1 {v3.b}[6], [x14], x7 + st1 {v3.h}[4], [x11], x7 + st1 {v3.b}[10], [x14], x7 + st1 {v3.h}[6], [x11], x7 + st1 {v3.b}[14], [x14], x7 + st1 {v4.h}[0], [x11], x7 + st1 {v4.b}[2], [x14], x7 + st1 {v4.h}[2], [x11], x7 + st1 {v4.b}[6], [x14], x7 + st1 {v4.h}[4], [x11], x7 + st1 {v4.b}[10], [x14], x7 + st1 {v4.h}[6], [x11], x7 + st1 {v4.b}[14], [x14], x7 + st1 {v5.h}[0], [x11], x7 + st1 {v5.b}[2], [x14], x7 + st1 {v5.h}[2], [x11], x7 + st1 {v5.b}[6], [x14], x7 + st1 {v5.h}[4], [x11], x7 + st1 {v5.b}[10], [x14], x7 + st1 {v5.h}[6], [x11], x7 + st1 {v5.b}[14], [x14], x7 + add x0, x0, #3 + b WriteEnd + Write4: + st1 {v0.s}[0], [x11], x7 + st1 {v0.s}[1], [x11], x7 + st1 {v0.s}[2], [x11], x7 + st1 {v0.s}[3], [x11], x7 + st1 {v1.s}[0], [x11], x7 + st1 {v1.s}[1], [x11], x7 + st1 {v1.s}[2], [x11], x7 + st1 {v1.s}[3], [x11], x7 + st1 {v2.s}[0], [x11], x7 + st1 {v2.s}[1], [x11], x7 + st1 {v2.s}[2], [x11], x7 + st1 {v2.s}[3], [x11], x7 + st1 {v3.s}[0], [x11], x7 + st1 {v3.s}[1], [x11], x7 + st1 {v3.s}[2], [x11], x7 + st1 {v3.s}[3], [x11], x7 + st1 {v4.s}[0], [x11], x7 + st1 {v4.s}[1], [x11], x7 + st1 {v4.s}[2], [x11], x7 + st1 {v4.s}[3], [x11], x7 + st1 {v5.s}[0], [x11], x7 + st1 {v5.s}[1], [x11], x7 + st1 {v5.s}[2], [x11], x7 + st1 {v5.s}[3], [x11] + add x0, x0, #4 + + WriteEnd: + + subs x10, x10, #1 + bne LoopKsize + + subs x6, x6, #4 + cbz x3, NoStepFowrard + add x3, x3, #16 + NoStepFowrard: + bgt LoopOc + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc new file mode 100644 index 00000000000..810cf6f7645 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/common_func.h" +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +#ifndef ENABLE_ARM +void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, + int output_channel, size_t offset, size_t relu, size_t relu6) { + for (int i = 0; i < TILE_NUM; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * output_channel; + for (int j = 0; j < output_channel; j++) { + int oc8_block = j / C8NUM; + int oc8_res = j % C8NUM; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; + int out_oc_offset = output_tile_offset + j; + + float acc = 0; + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; + for (int m = 0; m < C4NUM; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * C8NUM; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + } + acc += bias[j]; + if (relu) { + acc = acc > 0 ? acc : 0; + } else if (relu6) { + if (acc < 0) { + acc = 0; + } else if (acc > 6) { + acc = 6; + } else { + } + } + (output + out_oc_offset)[0] = acc; + } + } +} + +void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6) { + int oc4 = UP_DIV(output_channel, C4NUM); + if (mode && writeC4) { + for (int i = 0; i < TILE_NUM; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * oc4 * C4NUM * step; + for (int j = 0; j < output_channel; j++) { + int oc4_block = j / 4; + int oc4_res = j % 4; + int oc8_block = oc4_block / 2; + int oc8_res = oc4_block % 2; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res * C4NUM + oc4_res; + int out_oc_offset = output_tile_offset + oc4_block * step * C4NUM + oc4_res; + + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + int output_kw_offset = out_oc_offset + n * C4NUM; + float acc = 0; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; + for (int m = 0; m < 4; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * C8NUM; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + (output + output_kw_offset)[0] = acc; + } + } + } + } else if (mode) { + IndirectGemmFp32_Comm(output, input, weight, ic4, C8NUM, output_channel, offset); + } else { + IndirectGemmFp32(output, input, weight, bias, step, ic4, output_channel, offset, relu, relu6); + } +} +#endif + +int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } + +int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } + +void ReluFp32(float *data, int ele_num) { + for (int i = 0; i < ele_num; i++) { + if (data[i] < 0) { + data[i] = 0; + } else { + // do nothing + } + } +} + +void Relu6Fp32(float *data, int ele_num) { + for (int i = 0; i < ele_num; i++) { + if (data[i] < 0) { + data[i] = 0; + } else if (data[i] > 6) { + data[i] = 6; + } else { + // do nothing + } + } +} + +void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, + size_t offset) { + for (int r = 0; r < hw; r++) { + for (int c = 0; c < oc; c++) { + float value = 0; + for (int deep = 0; deep < ic4; deep++) { + int d4mod = deep % 4; + int d4div = deep / 4; + int a_index = d4div * 4 * 8 + r * 4 + d4mod; + int b_index = 8 * deep + c; + value += input[a_index] * weight[b_index]; + } + output[r * offset + c] = value; + } + } + return; +} + +void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { + /* (int32_t)row8x8-major * multiplier + bias => (int8)relu => (int8_t)row-major */ + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c8div = c / 8, c8mod = c % 8; + int src_index = c8div * plane8 * 8 + r * 8 + c8mod; + int dst_index = r * oc + c; + int32_t value = in[src_index]; + if (bias != nullptr) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h new file mode 100644 index 00000000000..83a67bf771d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int8_t MinInt8(int8_t a, int8_t b); +int8_t MaxInt8(int8_t a, int8_t b); +void ReluFp32(float *data, int ele_num); +void Relu6Fp32(float *data, int ele_num); +void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); +void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, + size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, + size_t offset); +void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, + int output_channel, size_t offset, size_t relu, size_t relu6); + +#ifdef ENABLE_ARM64 +void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); +void Relu6(float *data, size_t element4); +void Relu(float *data, size_t element4); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_COMMON_FUNC_H_ */ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/concat_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/concat_parameter.h new file mode 100644 index 00000000000..55a6dd0e3c0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/concat_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONCAT_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONCAT_PARAMETER_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +struct ConcatParameter { + OpParameter op_parameter_; + ConcatQuantArg *concat_quant_arg_; + int axis_; + int thread_count_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONCAT_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h new file mode 100644 index 00000000000..8d5ad1b708f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/conv_parameter.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONV_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONV_PARAMETER_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/quantization/quantize.h" + +struct ConvParameter { + OpParameter op_parameter_; + ConvQuantArg conv_quant_arg_; + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_h_; + int pad_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int group_; + int n_dim_; + int input_batch_; + int input_h_; + int input_w_; + int input_channel_; + int output_batch_; + int output_h_; + int output_w_; + int output_channel_; + int thread_num_; + int input_unit_; + int output_unit_; + bool is_relu_; + bool is_relu6_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CONV_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/errorcode.h b/mindspore/lite/src/runtime/kernel/arm/opclib/errorcode.h new file mode 100644 index 00000000000..3d6eeb80b27 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/errorcode.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ERRORCODE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ERRORCODE_H_ + +enum ErrorCodeCommonEnum { + OPCLIB_OK = 0, + OPCLIB_ERR = 1, + OPCLIB_NULL_PTR, + OPCLIB_PARAM_INVALID, + OPLIB_COMMON_END = 9999 +}; + +enum ErrorCodeFp32OpEnum { + OPCLIB_ERRCODE_OP_FP32_START = 10000, + OPCLIB_ERRCODE_STRASSEN_RECURSION_MALLOC, + OPCLIB_ERRCODE_REVERSE_MALLOC, + OPCLIB_ERRCODE_SQRT_NEGATIVE, + OPCLIB_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, + OPCLIB_ERRCODE_LOG_NEGATIVE_OR_ZERO, + OPCLIB_ERRCODE_DIVISOR_ZERO, + OPCLIB_ERRCODE_INDEX_OUT_OF_RANGE, + OPCLIB_ERRCODE_OP_FP32_END = 19999 +}; + +enum ErrorCodeFp16OpEnum { OPCLIB_ERRCODE_OP_FP16_START = 20000, OPCLIB_ERRCODE_OP_FP16_END = 29999 }; + +enum ErrorCodeUint8OpEnum { OPCLIB_ERRCODE_OP_UINT8_START = 30000, OPCLIB_ERRCODE_OP_UINT8_END = 39999 }; + +enum ErrorCodeInt8OpEnum { OPCLIB_ERRCODE_OP_INT8_START = 40000, OPCLIB_ERRCODE_OP_INT8_END = 49999 }; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ERRORCODE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.cc new file mode 100644 index 00000000000..a78c0397d9b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/flatten.h" +#include + +void Flatten(const void *input, void *output, FlattenParameter *flatten_param) { + memcpy(output, input, flatten_param->size); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.h b/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.h new file mode 100644 index 00000000000..fa8267b5594 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/flatten.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FLATTEN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FLATTEN_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct FlattenParameter { + OpParameter op_parameter_; + int size; +}; + +void Flatten(const void *input, void *output, FlattenParameter *flatten_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FLATTEN_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc new file mode 100644 index 00000000000..50c57f31329 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc @@ -0,0 +1,219 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/fp16/conv_fp16.h" +#include +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" + +extern "C" { +#ifdef ENABLE_ARM64 +#ifdef ENABLE_FP16 +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +#endif +#endif +} + +#ifdef ENABLE_FP16 +#ifndef ENABLE_NEON +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6) { + int tile_n = 16; + for (int i = 0; i < out_channel; i++) { + int oc8_block = i / 8; + int oc8_res = i % 8; + int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; + for (int k = 0; k < tile_n; k++) { + int input_tile_offset = k * C4NUM; + int out_tile_offset = i + k * out_channel; + + float16_t tmp_out = 0; + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * tile_n * ic4 * C4NUM; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; + for (int j = 0; j < ic4; j++) { + int input_ic4_offset = input_kw_offset + j * tile_n * C4NUM; + int weight_ic4_offset = weight_kw_offset + j * C4NUM * C8NUM; + for (int m = 0; m < C4NUM; m++) { + int input_c4_offset = input_ic4_offset + m; + int weight_c4_offset = weight_ic4_offset + m * C8NUM; + tmp_out += (input + input_c4_offset)[0] * (weight + weight_c4_offset)[0]; + } + } + } + + (output + out_tile_offset)[0] = tmp_out; + } + } +} + +void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *weight, const float16_t *bias, + size_t step, size_t ic4, size_t output_channel, size_t offset, size_t mode, + size_t writeC4, size_t relu, size_t relu6) { + int tile_num = 16; + if (mode) { + for (int i = 0; i < tile_num; i++) { + int input_tile_offset = i * C4NUM; + int output_tile_offset = i * output_channel * 36; + for (int j = 0; j < output_channel; j++) { + int oc8_block = j / 8; + int oc8_res = j % 8; + int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res; + // todo nc4hw4 -> nhwc + int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res; + + for (int n = 0; n < step; n++) { + int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * tile_num; + int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * 8; + int output_kw_offset = out_oc_offset + n * C8NUM; + float16_t acc = 0; + + for (int k = 0; k < ic4; k++) { + int input_ic4_offset = input_kw_offset + k * tile_num * C4NUM; + int weight_ic4_offset = weight_kw_offset + k * C4NUM * 8; + for (int m = 0; m < 4; m++) { + int input_ic_offset = input_ic4_offset + m; + int weight_ic_offset = weight_ic4_offset + m * 8; + acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; + } + } + + (output + output_kw_offset)[0] = acc; + } + } + } + } else { + } +} +#endif + +// fp16 convolution common (im2col+gemm) +void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + + // todo + int thread_count = conv_param->thread_num_; + int tile_n = 16; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + + int channel_block = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * channel_block * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + // we accumulate 4 channels per time for input blocks + int ic4 = UP_DIV(in_channel, C4NUM); + int oc8 = UP_DIV(in_channel, C8NUM); + int conv_depth = kernel_h * kernel_w; + // bytes from one output's i-th channel to the next output's i-th channel + // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + float16_t *gemm_input = + (float *)(packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset); + Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + if (real_cal_num == tile_n) { + float16_t *gemm_output = output_data + out_offset; + IndirectGemmFp16_16x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + oc8 * C8NUM * sizeof(float16_t), 0, 0, 0, 0); + } else { + // res part + IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + oc8 * C8NUM * sizeof(float16_t), 0, 0, 0, 0); + memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t)); + } + } + } +} + +// fp16 conv3x3 +void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, + float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, + int task_id, ConvParameter *conv_param) { + // todo + int thread_count = conv_param->thread_num_; + int tile_num = 16; + int output_unit = 4; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + + int output_batch = conv_param->output_batch_; + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + + int out_w_block = UP_DIV(conv_param->output_w_, C4NUM); + int out_h_block = UP_DIV(conv_param->output_h_, C4NUM); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, tile_num); + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_num; + int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; + + Conv3x3Fp16InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, + conv_param); + + IndirectGemmFp16_16x8(tmp_dst_buffer, tile_buffer, transed_weight, NULL, 36, ic4, oc8 * C8NUM, + oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); + + Conv3x3Fp16OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block, + conv_param); + } + } + + // get real output + // todo + for (int batch = 0; batch < output_batch; batch++) { + int batch_size = batch * output_channel * output_h * output_w; + for (int h = 0; h < output_h; h++) { + for (int w = 0; w < output_w; w++) { + for (int c = 0; c < output_channel; c++) { + int oc8_block = c / C8NUM; + int oc8_res = c % C8NUM; + int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * tile_num + + C8NUM * (h * out_w_block * output_unit + w) + oc8_res; + int dst_offset = (h * output_w + w) * output_channel + c; + (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; + } + } + } + } +} +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h new file mode 100644 index 00000000000..457e483a983 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ + +#ifdef ENABLE_FP16 +#include +#endif +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" + +#ifdef ENABLE_FP16 +#ifndef ENABLE_NEON +void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, + size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, + size_t relu6); +#endif + +// fp16 convolution common (im2col+gemm) +void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, + float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param); + +// fp16 conv3x3 +void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16_t *bias_data, float16_t *output_data, + float16_t *tile_buffer, float16_t *block_unit_buffer, float16_t *tmp_dst_buffer, float16_t *tmp_out, + int task_id, ConvParameter *conv_param); +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP16_CONV_FP16_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/activation.h new file mode 100644 index 00000000000..fcf480bd2b2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/activation.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_H_ + +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_{0.01}; +}; + +inline int Relu(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return OPCLIB_OK; +} + +inline int Relu6(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + if (src[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src[i] > 6.0f ? 6.0f : src[i]; + } + } + return OPCLIB_OK; +} + +inline int LRelu(const float *src, int length, float *dst, float alpha) { + for (int i = 0; i < length; ++i) { + dst[i] = src[i] > (src[i] * alpha) ? src[i] : (src[i] * alpha); + } + return OPCLIB_OK; +} + +inline int Sigmoid(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = 1.0f / (1.0f + exp(-src[i])); + } + return OPCLIB_OK; +} + +inline int Tanh(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = 1.0f - 2.0f / (exp(2 * src[i]) + 1); + } + return OPCLIB_OK; +} + +inline int HSwish(const float *src, int length, float *dst) { + for (int i = 0; i < length; ++i) { + float in = src[i]; + float relu6 = MSMIN(MSMAX(in + 3, 0), 6); + dst[i] = in * relu6 / 6; + } + return OPCLIB_OK; +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ACTIVATION_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.cc new file mode 100644 index 00000000000..e91fbbc0308 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/fp32/arg_min_max.h" +#include + +void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count, + int *after_axis_count) { + *pre_axis_count = 1; + for (int i = 0; i < axis; ++i) { + *pre_axis_count = (*pre_axis_count) * shape[i]; + } + + *axis_count = shape[axis]; + + *after_axis_count = 1; + for (int i = axis + 1; i < dims_number; ++i) { + *after_axis_count = (*after_axis_count) * shape[i]; + } +} + +void ArgMax(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + GetCalcParameter(shape, dims_number, axis, &pre_axis_count, &axis_count, &after_axis_count); + + for (int i = 0; i < pre_axis_count; ++i) { + int64_t output_offset = i * after_axis_count; + int64_t input_offset = output_offset * axis_count; + + for (int j = 0; j < after_axis_count; ++j) { + float value = -FLT_MAX; + float index = 0.0f; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } + output[output_offset + j] = out_value ? value : index; + } + } +} + +void ArgMin(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + GetCalcParameter(shape, dims_number, axis, &pre_axis_count, &axis_count, &after_axis_count); + + for (int i = 0; i < pre_axis_count; ++i) { + int64_t output_offset = i * after_axis_count; + int64_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = FLT_MAX; + float index = 0.0f; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + output[output_offset + j] = out_value ? value : index; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.h new file mode 100644 index 00000000000..83e5f6f993e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arg_min_max.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" + +// For arg min, arg max. +struct ArgMinMaxParameter { + OpParameter op_parameter_; + int axis_; + int topk_; + int axis_type_; + bool out_value_; + bool keep_dims_; +}; + +void ArgMax(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output); +void ArgMin(const float *input, const int *shape, int dims_number, int axis, bool out_value, float *output); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ARG_MIN_MAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.cc new file mode 100644 index 00000000000..e7dcbb32b98 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.cc @@ -0,0 +1,526 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" + +int ElementMul(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmulq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] * input1[0]; + output[1] = input0[1] * input1[1]; + output[2] = input0[2] * input1[2]; + output[3] = input0[3] * input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] * input1[index]; + } + + return OPCLIB_OK; +} + +int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMul(tile_input0, tile_input1, output, element_size); +} + +int ElementAdd(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vaddq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] + input1[0]; + output[1] = input0[1] + input1[1]; + output[2] = input0[2] + input1[2]; + output[3] = input0[3] + input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] + input1[index]; + } + return OPCLIB_OK; +} + +int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input0[i] + input1[i]; + } + return OPCLIB_OK; +} + +int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementAdd(tile_input0, tile_input1, output, element_size); +} + +int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output, + int element_size, ArithmeticParameter *param) { + TileDimensionsInt8(input0, input1, tile_input0, tile_input1, param); + return ElementAddInt8(tile_input0, tile_input1, output, element_size); +} + +int ElementSub(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vsubq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] - input1[0]; + output[1] = input0[1] - input1[1]; + output[2] = input0[2] - input1[2]; + output[3] = input0[3] - input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] - input1[index]; + } + return OPCLIB_OK; +} + +int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementSub(tile_input0, tile_input1, output, element_size); +} + +// todo c=a/b,if(b==0) +int ElementDiv(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return OPCLIB_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] / input1[i]; + } + return OPCLIB_OK; +} + +int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementDiv(tile_input0, tile_input1, output, element_size); +} + +int ElementFloorMod(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return OPCLIB_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + return OPCLIB_OK; +} + +int BroadcastFloorMod(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementFloorMod(tile_input0, tile_input1, output, element_size); +} + +int ElementFloorDiv(float *input0, float *input1, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input1[i] == 0) { + return OPCLIB_ERRCODE_DIVISOR_ZERO; + } + output[i] = floorf(input0[i] / input1[i]); + } + return OPCLIB_OK; +} + +int BroadcastFloorDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementFloorDiv(tile_input0, tile_input1, output, element_size); +} + +int ElementLogicalAnd(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vandq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = (float)((bool)(input0[0]) & (bool)(input1[0])); + output[1] = (float)((bool)(input0[1]) & (bool)(input1[1])); + output[2] = (float)((bool)(input0[2]) & (bool)(input1[2])); + output[3] = (float)((bool)(input0[3]) & (bool)(input1[3])); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)((bool)(input0[index]) & (bool)(input1[index])); + } + return OPCLIB_OK; +} + +int ElementSquaredDifference(float *input0, float *input1, float *output, int element_size) { + ElementSub(input0, input1, output, element_size); + return ElementMul(output, output, output, element_size); +} + +int BroadcastSquaredDifference(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + BroadcastSub(input0, input1, tile_input0, tile_input1, output, element_size, param); + return ElementMul(output, output, output, element_size); +} + +int BroadcastLogicalAnd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLogicalAnd(tile_input0, tile_input1, output, element_size); +} + +int ElementLogicalOr(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vorrq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = (float)((bool)(input0[0]) | (bool)(input1[0])); + output[1] = (float)((bool)(input0[1]) | (bool)(input1[1])); + output[2] = (float)((bool)(input0[2]) | (bool)(input1[2])); + output[3] = (float)((bool)(input0[3]) | (bool)(input1[3])); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)((bool)(input0[index]) | (bool)(input1[index])); + } + return OPCLIB_OK; +} + +int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLogicalOr(tile_input0, tile_input1, output, element_size); +} + +int ElementMaximum(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmaxq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] > input1[0] ? input0[0] : input1[0]; + output[1] = input0[1] > input1[1] ? input0[1] : input1[1]; + output[2] = input0[2] > input1[2] ? input0[2] : input1[2]; + output[3] = input0[3] > input1[3] ? input0[3] : input1[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] > input1[index] ? input0[index] : input1[index]; + } + return OPCLIB_OK; +} + +int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMaximum(tile_input0, tile_input1, output, element_size); +} + +int ElementMinimum(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + output[0] = input0[0] > input1[0] ? input1[0] : input0[0]; + output[1] = input0[1] > input1[1] ? input1[1] : input0[1]; + output[2] = input0[2] > input1[2] ? input1[2] : input0[2]; + output[3] = input0[3] > input1[3] ? input1[3] : input0[3]; +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; + } + return OPCLIB_OK; +} + +int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementMinimum(tile_input0, tile_input1, output, element_size); +} + +int ElementNotEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vfalse, vtrue); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] != input1[0]); + output[1] = (float)(input0[1] != input1[1]); + output[2] = (float)(input0[2] != input1[2]); + output[3] = (float)(input0[3] != input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] != input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementNotEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vceqq_fp32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] == input1[0]); + output[1] = (float)(input0[1] == input1[1]); + output[2] = (float)(input0[2] == input1[2]); + output[3] = (float)(input0[3] == input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] == input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementLess(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcltq_fp32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] < input1[0]); + output[1] = (float)(input0[1] < input1[1]); + output[2] = (float)(input0[2] < input1[2]); + output[3] = (float)(input0[3] < input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] < input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLess(tile_input0, tile_input1, output, element_size); +} + +int ElementLessEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcleq_fp32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] <= input1[0]); + output[1] = (float)(input0[1] <= input1[1]); + output[2] = (float)(input0[2] <= input1[2]); + output[3] = (float)(input0[3] <= input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] <= input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementLessEqual(tile_input0, tile_input1, output, element_size); +} + +int ElementGreater(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcgtq_fp32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] > input1[0]); + output[1] = (float)(input0[1] > input1[1]); + output[2] = (float)(input0[2] > input1[2]); + output[3] = (float)(input0[3] > input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] > input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementGreater(tile_input0, tile_input1, output, element_size); +} + +int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; +#ifdef USE_NEON + float32x4_t vtrue = {1, 1, 1, 1}; + float32x4_t vfalse = {0, 0, 0, 0}; +#endif + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef USE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vbslq_f32(vcgeq_fp32(vin0, vin1), vtrue, vfalse); + vst1q_f32(output, vout); +#else + output[0] = (float)(input0[0] >= input1[0]); + output[1] = (float)(input0[1] >= input1[1]); + output[2] = (float)(input0[2] >= input1[2]); + output[3] = (float)(input0[3] >= input1[3]); +#endif + input0 += C4NUM; + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = (float)(input0[index] >= input1[index]); + } + return OPCLIB_OK; +} + +int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param) { + TileDimensions(input0, input1, tile_input0, tile_input1, param); + return ElementGreaterEqual(tile_input0, tile_input1, output, element_size); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.h new file mode 100644 index 00000000000..6788334803a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int ElementMul(float *input0, float *input1, float *output, int element_size); +int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementAdd(float *input0, float *input1, float *output, int element_size); +int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); +int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output, + int element_size, ArithmeticParameter *param); + +int ElementSub(float *input0, float *input1, float *output, int element_size); +int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementDiv(float *input0, float *input1, float *output, int element_size); +int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementLogicalAnd(float *input0, float *input1, float *output, int element_size); +int BroadcastLogicalAnd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementLogicalOr(float *input0, float *input1, float *output, int element_size); +int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementMaximum(float *input0, float *input1, float *output, int element_size); +int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementMinimum(float *input0, float *input1, float *output, int element_size); +int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementFloorDiv(float *input0, float *input1, float *output, int element_size); +int BroadcastFloorDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementFloorMod(float *input0, float *input1, float *output, int element_size); +int BroadcastFloorMod(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementSquaredDifference(float *input0, float *input1, float *output, int element_size); +int BroadcastSquaredDifference(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementNotEqual(float *input0, float *input1, float *output, int element_size); + +int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementEqual(float *input0, float *input1, float *output, int element_size); + +int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementLess(float *input0, float *input1, float *output, int element_size); +int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, + ArithmeticParameter *param); + +int ElementLessEqual(float *input0, float *input1, float *output, int element_size); +int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementGreater(float *input0, float *input1, float *output, int element_size); +int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); + +int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size); +int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, + int element_size, ArithmeticParameter *param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.cc new file mode 100644 index 00000000000..86ce36d09b1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h" + +// abs: +int ElementAbs(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return OPCLIB_OK; +} + +// cos: +int ElementCos(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return OPCLIB_OK; +} + +// exp: +int ElementExp(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = expf(input[i]); + } + return OPCLIB_OK; +} + +// log: +int ElementLog(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return OPCLIB_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return OPCLIB_OK; +} + +// Square +int ElementSquare(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return OPCLIB_OK; +} + +// Sqrt +int ElementSqrt(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] < 0) { + return OPCLIB_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return OPCLIB_OK; +} + +// rsqrt +int ElementRsqrt(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return OPCLIB_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + output[i] = 1.f / sqrtf(input[i]); + } + return OPCLIB_OK; +} + +// sin: +int ElementSin(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return OPCLIB_OK; +} + +// logical_not: +int ElementLogicalNot(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return OPCLIB_OK; +} + +// round: +int ElementRound(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = round(input[i]); + } + return OPCLIB_OK; +} + +// floor: +int ElementFloor(float *input, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return OPCLIB_OK; +} + +int ElementCeil(float *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ceil(input[i]); + } + return OPCLIB_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h new file mode 100644 index 00000000000..44489d628bb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_SELF_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_SELF_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. +struct ArithmeticSelfParameter { + OpParameter op_parameter_; +}; + +int ElementAbs(float *input, float *output, int element_size); + +int ElementCos(float *input, float *output, int element_size); + +int ElementExp(float *input, float *output, int element_size); + +int ElementLog(float *input, float *output, int element_size); + +int ElementSquare(float *input, float *output, int element_size); + +int ElementSqrt(float *input, float *output, int element_size); + +int ElementRsqrt(float *input, float *output, int element_size); + +int ElementSin(float *input, float *output, int element_size); + +int ElementLogicalNot(float *input, float *output, int element_size); + +int ElementRound(float *input, float *output, int element_size); + +int ElementFloor(float *input, float *output, int element_size); + +int ElementCeil(float *input, float *output, int number); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ARITHMETIC_SELF_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.cc new file mode 100644 index 00000000000..4faccbae116 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/batch_to_space.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +void BatchToSpaceNoCropForNHWC(const float *input, float *output, const int *in_shape, int out_n, const int *block) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t copy_size = in_c * 4; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy(output + output_offset, input + in_offset, copy_size); + output_offset += in_c; + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const float *input, float *output, const int *in_shape, int out_n, const int *block, + const int *crops) { + int block_h = block[0]; + int block_w = block[1]; + int in_n = in_shape[0]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + size_t stride_h = block_w * out_n; + size_t output_offset = 0; + size_t copy_size = in_c * 4; + size_t in_stride_h = in_w * in_c; + size_t in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + size_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + size_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + size_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + size_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + size_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy(output + output_offset, input + in_offset, copy_size); + output_offset += in_c; + } + } + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.h new file mode 100644 index 00000000000..a008222474a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/batch_to_space.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCH_TO_SPACE_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2 +#define BATCH_TO_SPACE_CROPS_SIZE 4 + +struct BatchToSpaceParameter { + OpParameter op_parameter_; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; + int32_t crops_[BATCH_TO_SPACE_CROPS_SIZE]; +}; + +void BatchToSpaceNoCropForNHWC(const float *input, float *output, const int *in_shape, int out_n, const int *block); +void BatchToSpaceForNHWC(const float *input, float *output, const int *in_shape, int out_n, const int *block, + const int *crops); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BATCH_TO_SPACE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.cc new file mode 100644 index 00000000000..747b286f023 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/broadcast_to.h" +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void PadBroadcastShapeInfo(BroadcastShapeInfo *shape_info) { + if (shape_info->input_shape_size_ < DIMENSION_4D) { + int input_shape_tmp[DIMENSION_4D]; + for (int i = 0; i < shape_info->input_shape_size_; ++i) { + input_shape_tmp[i] = shape_info->input_shape_[i]; + } + int input_shape_index = shape_info->input_shape_size_ - 1; + for (int i = DIMENSION_4D - 1; i >= 0; --i) { + if (input_shape_index >= 0) { + shape_info->input_shape_[i] = input_shape_tmp[input_shape_index--]; + } else { + shape_info->input_shape_[i] = 1; + } + } + } + if (shape_info->output_shape_size_ < DIMENSION_4D) { + int output_shape_tmp[DIMENSION_4D]; + for (int i = 0; i < shape_info->output_shape_size_; ++i) { + output_shape_tmp[i] = shape_info->output_shape_[i]; + } + int output_shape_index = shape_info->output_shape_size_ - 1; + for (int i = DIMENSION_4D - 1; i >= 0; --i) { + if (output_shape_index >= 0) { + shape_info->output_shape_[i] = output_shape_tmp[output_shape_index--]; + } else { + shape_info->output_shape_[i] = 1; + } + } + } +} + +int BroadcastTo(const float *input, BroadcastShapeInfo *shape_info, float *output) { + if (shape_info->input_shape_size_ > DIMENSION_4D || shape_info->output_shape_size_ > DIMENSION_4D) { + return -1; + } + PadBroadcastShapeInfo(shape_info); + size_t input_dim_offset[DIMENSION_4D - 1]; + input_dim_offset[2] = shape_info->input_shape_[3] * 4; + input_dim_offset[1] = input_dim_offset[2] * shape_info->input_shape_[2]; + input_dim_offset[0] = input_dim_offset[1] * shape_info->input_shape_[1]; + size_t output_dim_offset[DIMENSION_4D - 1]; + output_dim_offset[2] = shape_info->output_shape_[3] * 4; + output_dim_offset[1] = output_dim_offset[2] * shape_info->output_shape_[2]; + output_dim_offset[0] = output_dim_offset[1] * shape_info->output_shape_[1]; + uint8_t *in_base = (uint8_t *)input; + uint8_t *out_base = (uint8_t *)(output); + for (int32_t dim0 = 0; dim0 < shape_info->input_shape_[0]; ++dim0) { + for (int32_t dim1 = 0; dim1 < shape_info->input_shape_[1]; ++dim1) { + for (int32_t dim2 = 0; dim2 < shape_info->input_shape_[2]; ++dim2) { + if (shape_info->input_shape_[3] == shape_info->output_shape_[3]) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + output_dim_offset[2] * dim2, + in_base + input_dim_offset[0] * dim0 + input_dim_offset[1] * dim1 + + input_dim_offset[2] * dim2, input_dim_offset[2]); + } else { + for (int32_t dim3 = 0; dim3 < shape_info->output_shape_[3]; ++dim3) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + output_dim_offset[2] * dim2 + dim3 * 4, + in_base + input_dim_offset[0] * dim0 + input_dim_offset[1] * dim1 + + input_dim_offset[2] * dim2, 4); + } + } + } + if (shape_info->input_shape_[2] != shape_info->output_shape_[2]) { + for (int32_t dim2 = 0; dim2 < shape_info->output_shape_[2]; ++dim2) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1 + + dim2 * output_dim_offset[2], + out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1, + output_dim_offset[2]); + } + } + } + if (shape_info->input_shape_[1] != shape_info->output_shape_[1]) { + for (int32_t dim1 = 0; dim1 < shape_info->output_shape_[1]; ++dim1) { + memcpy(out_base + output_dim_offset[0] * dim0 + output_dim_offset[1] * dim1, + out_base + output_dim_offset[0] * dim0, output_dim_offset[1]); + } + } + } + if (shape_info->input_shape_[0] != shape_info->output_shape_[0]) { + for (int32_t dim0 = 0; dim0 < shape_info->output_shape_[0]; ++dim0) { + memcpy(out_base + output_dim_offset[0] * dim0, out_base, output_dim_offset[0]); + } + } + return 0; +} + + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.h new file mode 100644 index 00000000000..a2ea23bfbbc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/broadcast_to.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BROADCAST_TO_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BROADCAST_TO_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" + +#define BROADCAST_TO_SHAPE_MAX_SIZE 4 + +struct BroadcastToParameter { + OpParameter op_parameter_; + int shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + size_t shape_size_; +}; + +struct BroadcastShapeInfo { + int input_shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + int input_shape_size_; + int output_shape_[BROADCAST_TO_SHAPE_MAX_SIZE]; + int output_shape_size_; +}; + +int BroadcastTo(const float *input, BroadcastShapeInfo *shape_info, float *output); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_BROADCAST_TO_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.cc new file mode 100644 index 00000000000..d379200c8f6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/cast.h" + +void Uint8ToFloat32(const uint8_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Uint8ToInt8(const uint8_t *input, int8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int8_t)(input[i] - 128); + } +} + +void Int8ToUint8(const int8_t *input, uint8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (uint8_t)(input[i] + 128); + } +} + +void Int32ToFloat32(const int32_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +void Float32ToFloat16(const float *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Float16ToFloat32(const float16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.h new file mode 100644 index 00000000000..70135ee8fd9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/cast.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CAST_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CAST_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" + +// For cast. +struct CastParameter { + OpParameter op_parameter_; + int src_type_; + int dst_type_; +}; + +void Uint8ToFloat32(const uint8_t *input, float *output, int number); +void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); +void Int8ToUint8(const int8_t *input, uint8_t *output, int number); +void Int32ToFloat32(const int32_t *input, float *output, int number); +#ifdef ENABLE_FP16 +void Float32ToFloat16(const float *input, float16_t *output, int number); +void Float16ToFloat32(const float16_t *input, float *output, int number); +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_CAST_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc new file mode 100644 index 00000000000..ce5f1a06fdb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/common_func.h" + +#ifndef ENABLE_ARM +void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int a_index = c * a_stride + r * C4NUM; + int b_index = c * b_stride + r * C4NUM; + int c_index = c * c_stride + r * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst[c_index + i] = a_ptr[a_index + i] + b_ptr[b_index + i]; + } + } + } + return; +} + +void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int a_index = c * a_stride + r * C4NUM; + int b_index = c * b_stride + r * C4NUM; + int c_index = c * c_stride + r * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst[c_index + i] = a_ptr[a_index + i] - b_ptr[b_index + i]; + } + } + } + return; +} +#endif + +void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, + size_t c_stride, size_t x_stride) { + /* U2 = P1 + P6 */ + MatrixAdd(x_ptr, c12, c12, x_stride, c_stride, c_stride, row, col); + /* U3 = U2 + P7 */ + MatrixAdd(c12, c21, c21, c_stride, c_stride, c_stride, row, col); + /* U4 = U2 + P5 */ + MatrixAdd(c12, c22, c12, c_stride, c_stride, c_stride, row, col); + /* U7 = U3 + P5 */ + MatrixAdd(c21, c22, c22, c_stride, c_stride, c_stride, row, col); + /* U5 = U4 + P3 */ + MatrixAdd(c12, c11, c12, c_stride, c_stride, c_stride, row, col); + return; +} + +void PostConvFuncFp32(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6) { +#ifndef ENABLE_ARM64 + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / 4, oc4mod = oc % 4; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc4div * 4 * plane_size + hw * 4 + oc4mod; + int dst_index = hw * stride + oc; + float value = c4_out_ptr[src_index]; + if (bias_ptr != nullptr) { + value = value + bias_ptr[oc]; + } + value = (is_relu) ? (MSMAX(0, value)) : (value); + value = (is_relu6) ? (MSMIN(6, MSMAX(0, value))) : (value); + out_ptr[dst_index] = value; + } + } +#else + int oc4 = UP_DIV(output_channel, C4NUM); + if (bias_ptr != nullptr) { + if (is_relu) { + BiasAddRelu(bias_ptr, out_ptr, oc4, plane_size); + } else if (is_relu6) { + BiasAddRelu6(bias_ptr, out_ptr, oc4, plane_size); + } else { + BiasAdd(bias_ptr, out_ptr, oc4, plane_size); + } + } else { + if (is_relu) { + Relu(out_ptr, oc4 * plane_size); + } else if (is_relu6) { + Relu6(out_ptr, oc4 * plane_size); + } else { + // do nothing + } + } +#endif + return; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.h new file mode 100644 index 00000000000..0611df88c4c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/common_func.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_COMMON_FUNC_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_COMMON_FUNC_H_ + +#include +#include +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PostConvFuncFp32(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, bool is_relu, bool is_relu6); +void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col); +void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride, + size_t row, size_t col); +void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, + size_t c_stride, size_t x_stride); + +#ifdef ENABLE_ARM64 +void BiasAdd(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu6(const float *bias, float *data, size_t oc4, size_t plan_size); +void BiasAddRelu(const float *bias, float *data, size_t oc4, size_t plan_size); +void Relu6(float *data, size_t element4); +void Relu(float *data, size_t element4); +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_COMMON_FUNC_H_ */ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.cc new file mode 100644 index 00000000000..c410a96cba7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/concat.h" +#include + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) { + int before_axis_size = 1; + for (int i = 0; i < axis; ++i) { + before_axis_size *= inputs_output_shape[0][i]; + } + // sizeof float/int32 + int after_axis_size = 4; + for (size_t i = axis + 1; i < shape_size; ++i) { + after_axis_size *= inputs_output_shape[0][i]; + } + int axis_offset = 0; + uint8_t *dst_base = reinterpret_cast(output); + size_t output_stride = after_axis_size * inputs_output_shape[input_num][axis]; + for (int i = 0; i < input_num; ++i) { + uint8_t *src_base = reinterpret_cast(input[i]); + size_t input_stride = after_axis_size * inputs_output_shape[i][axis]; + for (int j = 0; j < before_axis_size; ++j) { + uint8_t *src = src_base + j * input_stride; + uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size; + memcpy(dst, src, input_stride); + } + axis_offset += inputs_output_shape[i][axis]; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.h new file mode 100644 index 00000000000..e8ed2690eea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/concat.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONCAT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONCAT_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONCAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc new file mode 100644 index 00000000000..cc121dc4b54 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" + +// fp32 conv common +void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data, + float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int thread_count = conv_param->thread_num_; + int tile_n = 8; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + // we accumulate 4 channels per time for input blocks + int conv_depth = kernel_h * kernel_w; + // bytes from one output's i-th channel to the next output's i-th channel + // we write 32 bytes per st1 instruction, after which the pointer in register will step 32B forward + size_t output_offset = out_channel * sizeof(float); + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + float *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + if (real_cal_num == tile_n) { + float *gemm_output = output_data + out_offset; + IndirectGemmFp32_8x8(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_); + } else { + // res part + IndirectGemmFp32_8x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, + output_offset, 0, 0, conv_param->is_relu_, conv_param->is_relu6_); + memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float)); + } + } + } +} + +// fp32 conv1x1 strassen matmul +int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr, + StrassenMatMulParameter matmul_param) { + return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr); +} + +// fp32 conv winograd +void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) { + int thread_num = conv_param->thread_num_; + int input_unit = conv_param->input_unit_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_unit = conv_param->output_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, out_unit); + int out_h_block = UP_DIV(conv_param->output_h_, out_unit); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int out_channel = conv_param->output_channel_; + int out_batch = conv_param->output_batch_; + int oc4 = UP_DIV(out_channel, C4NUM); + int input_unit_square = input_unit * input_unit; + size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float); + bool is_relu = conv_param->is_relu_; + bool is_relu6 = conv_param->is_relu6_; + + float *trans_input = buffer_list[0]; + float *gemm_out = buffer_list[1]; + float *tmp_out_data = buffer_list[2]; + float *tmp_data = buffer_list[3]; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + for (int b = 0; b < in_batch; b++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { + int out_tile_index = thread_id * TILE_NUM; + int cal_num = output_count - thread_id * TILE_NUM; + cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num; + WinogradInputTransform(input_data, trans_input, tmp_data, cal_num, out_tile_index, out_w_block, conv_param, + input_trans_func); + // step 3 : gemm + IndirectGemmFp32_8x8(gemm_out, trans_input, trans_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, + output_offset, 1, 1, 0, 0); + + // step 4 : output transform + WinogradOutputTransform(gemm_out, tmp_out_data, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + output_trans_func); + } + } + // get real output + for (int batch = 0; batch < out_batch; batch++) { + int batch_size = batch * out_channel * conv_param->output_h_ * conv_param->output_w_; + for (int h = 0; h < conv_param->output_h_; h++) { + for (int w = 0; w < conv_param->output_w_; w++) { + for (int c = 0; c < out_channel; c++) { + int oc4_block = c / C4NUM; + int oc4_res = c % C4NUM; + int src_offset = oc4_block * C4NUM * out_w_block * out_h_block * out_unit * out_unit + + C4NUM * (h * out_w_block * out_unit + w) + oc4_res; + int dst_offset = (h * conv_param->output_w_ + w) * out_channel + c; + (output_data + dst_offset)[0] = (tmp_out_data + src_offset)[0]; + } + } + } + } + + int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; + if (is_relu) { + ReluFp32(output_data, output_num); + } else if (is_relu6) { + Relu6Fp32(output_data, output_num); + } else { + // do nothing + } +} + +// fp32 conv3x3 +void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param) { + int thread_count = conv_param->thread_num_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int input_unit_square = 4 * 4; + bool is_relu = conv_param->is_relu_; + bool is_relu6 = conv_param->is_relu6_; + float *tile_buffer = buffer_list[0]; + float *block_unit_buffer = buffer_list[1]; + float *tmp_dst_buffer = buffer_list[2]; + float *nc4hw4_out = buffer_list[3]; + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + Conv3x3Fp32InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, + conv_param); + + IndirectGemmFp32_8x8(tmp_dst_buffer, tile_buffer, transed_weight, nullptr, input_unit_square, ic4, oc4 * C4NUM, + oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0); + + Conv3x3Fp32OutputTransform(tmp_dst_buffer, nc4hw4_out, bias_data, start_index, real_cal_num, out_w_block, + conv_param); + } + PackNC4HW4ToNHWCFp32(nc4hw4_out, output_data, 1, conv_param->output_h_ * conv_param->output_w_, output_channel); + } + int output_num = oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_; + if (is_relu) { + ReluFp32(output_data, output_num); + } else if (is_relu6) { + Relu6Fp32(output_data, output_num); + } else { + // do nothing + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h new file mode 100644 index 00000000000..5c1d0965320 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/common_func.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h" +#include "src/runtime/kernel/arm/opclib/winograd_utils.h" + +using TmpBufferAddress = float *; + +// fp32 convolution common (im2col+gemm) +void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data, + float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param); + +// fp32 conv1x1 strassen matmul +int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr, + StrassenMatMulParameter matmul_param); + +// fp32 convolution winograd +void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func); + +// fp32 conv3x3 +void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, + ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CONV_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc new file mode 100644 index 00000000000..3ae161f7b6b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc @@ -0,0 +1,351 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" +#ifdef ENABLE_ARM64 +#include +#endif + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + int left = 0; + int right = conv_param->output_w_; + int top = 0; + int bottom = conv_param->output_h_; + + for (; left * conv_param->stride_w_ < conv_param->pad_w_; left++) { + } + for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_w_ + conv_param->kernel_w_ * conv_param->dilation_w_ > + conv_param->input_w_ && + right > left; + right--) { + } + for (; top * conv_param->stride_h_ < conv_param->pad_h_; top++) { + } + for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_h_ + conv_param->kernel_h_ * conv_param->dilation_h_ > + conv_param->input_h_ && + bottom > top; + bottom--) { + } + sliding->left_ = left; + sliding->right_ = right; + sliding->top_ = top; + sliding->bottom_ = bottom; + sliding->c_block_ = UP_DIV(conv_param->output_channel_, block); + sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block; + + sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_; + sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_; + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_h_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; +} + +/*conv depthwise fp32 begin*/ +void DepthwiseBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w, bool is_relu, bool is_relu6) { + const float *src_kh = src; + const float *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src_kw); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + float *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sliding->in_h_step_; + + float *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sliding->block_channel_; + + const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + + DepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, conv_param->is_relu_, + conv_param->is_relu6_); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float *src_kh = src_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src_kw); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_w); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_w, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +// conv depthwise fp32: sliding window +void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DepthwiseBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DepthwiseBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + + DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, + conv_param->is_relu_, conv_param->is_relu6_); + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nc4hwc4 +} +/*conv depthwise fp32 end*/ + +/*deconv depthwise fp32 begin*/ +void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + float *dst_kh = dst; + const float *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_kw); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_kw, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right, + const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const float *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float *dst_h = dst + oh * sliding->in_h_step_; + + const float *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float *dst_w = dst_h + ow * sliding->block_channel_; + + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src_w); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_kw); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_kw, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} + +void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { + float *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]); + } + dst_k += block_channel; + } +} + +// deconv depthwise fp32: sliding window +void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DeconvDepthwiseBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_, + conv_param, sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DeconvDepthwiseBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + + DeconvDepthwiseCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, + sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, + sliding->in_kw_step_); + } + DeconvDepthwisePostFunc(dst_data, bias, sliding->block_channel_, conv_param); + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nc4hwc4 +} +/*deconv depthwise fp32 end*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h new file mode 100644 index 00000000000..a952722a9d2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_P32_CONV_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_P32_CONV_DEPTHWISE_H_ + +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" + +struct SlidingWindowParam { + int left_; + int right_; + int top_; + int bottom_; + int c_block_; + int block_channel_; + int out_step_; + int out_h_step_; + int in_step_; + int in_h_step_; + int in_sh_step_; // stride H + int in_sw_step_; // stride W + int in_kh_step_; // kernel H + int in_kw_step_; // kernel W + int kernel_step_; +}; + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_P32_CONV_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.cc new file mode 100644 index 00000000000..1f1e0135653 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/fp32/crop.h" +#include + +void Pad4DOffset(CropParameter *crop_param) { + int64_t offset_tmp[DIMENSION_4D]; + int axis = crop_param->axis_; + for (int i = 3; i >= 0; --i) { + int offset_index = i - axis; + if (offset_index >= 0) { + offset_tmp[i] = crop_param->offset_[offset_index]; + } else { + offset_tmp[i] = 0; + } + } + for (int i = 0; i < DIMENSION_4D; ++i) { + crop_param->offset_[i] = offset_tmp[i]; + } +} + +void Crop4D(const float *input, float *output, const int *in_shape, const int *out_shape, CropParameter *crop_param) { + Pad4DOffset(crop_param); + size_t in_dim2_stride = in_shape[3]; + size_t in_dim1_stride = in_shape[2] * in_dim2_stride; + size_t in_dim0_stride = in_dim1_stride * in_shape[1]; + size_t offset_3 = crop_param->offset_[3]; + size_t out_offset = 0; + size_t copy_num = out_shape[3]; + size_t copy_size = copy_num * sizeof(float); + size_t in_dim0_end = crop_param->offset_[0] + out_shape[0]; + size_t in_dim1_end = crop_param->offset_[1] + out_shape[1]; + size_t in_dim2_end = crop_param->offset_[2] + out_shape[2]; + for (int i = crop_param->offset_[0]; i < in_dim0_end; ++i) { + size_t dim0_offset = i * in_dim0_stride + offset_3; + for (int j = crop_param->offset_[1]; j < in_dim1_end; ++j) { + size_t dim1_offset = j * in_dim1_stride + dim0_offset; + for (int k = crop_param->offset_[2]; k < in_dim2_end; ++k) { + size_t in_offset = dim1_offset + k * in_dim2_stride; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += copy_num; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.h new file mode 100644 index 00000000000..3d61355e6cd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/crop.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CROP_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CROP_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +struct CropParameter { + OpParameter op_parameter_; + int64_t offset_[CROP_OFFSET_MAX_SIZE]; + int64_t axis_; +}; + +void Crop4D(const float *input, float *output, const int *in_shape, const int *out_shape, CropParameter *crop_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_CROP_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.cc new file mode 100644 index 00000000000..b2d3a4151f5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/deconv.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane) { + /* ichwoc(nhwc) -> oc4 * h * w * incUP4 * 4 */ + int ic_up4 = UP_ROUND(input_channel, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM; + int oc4mod = oc % C4NUM; + for (int ic = 0; ic < input_channel; ic++) { + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * plane * output_channel + hw * output_channel + oc; + int dst_index = oc4div * ic_up4 * plane * C4NUM + hw * ic_up4 * C4NUM + ic * C4NUM + oc4mod; + dst[dst_index] = weight[src_index]; + } + } + } + return; +} + +int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, + StrassenMatMulParameter matmul_param) { + return StrassenMatmul(input, weight, output, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_buffer); +} + +int DeConvPostFp32(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, int input_plane, + int kernel_plane, int output_plane, ConvParameter *conv_param) { + int oc4 = UP_DIV(output_channel, C4NUM); + for (int c = 0; c < oc4; c++) { + float *dst_ptr = tmp_c4 + c * output_plane * C4NUM; + const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM; + memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM + + kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM; + int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM + + kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM + + kw * conv_param->dilation_w_ * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst_ptr[dst_index + i] += src_ptr[src_index + i]; + } + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc4*/ + + PostConvFuncFp32(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_, + conv_param->is_relu6_); + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.h new file mode 100644 index 00000000000..6f5af130121 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/deconv.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DECONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DECONV_H_ + +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane); + +int DeConvFp32(const float *input, const float *weight, float *output, float *tmp_buffer, + StrassenMatMulParameter matmul_param); + +int DeConvPostFp32(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel, int input_plane, + int kernel_plane, int output_plane, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DECONV_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.cc new file mode 100644 index 00000000000..78680c4938a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/fp32/depth_to_space.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +void DepthToSpaceForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, + int block_size) { + int *in_strides = (int *)(malloc(sizeof(int) * shape_size)); + ComputeStrides(in_shape, in_strides, shape_size); + int *out_strides = (int *)(malloc(sizeof(int) * shape_size)); + ComputeStrides(out_shape, out_strides, shape_size); + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_offset_n = i * in_strides[0]; + size_t out_offset_n = i * out_strides[0]; + for (int j = 0; j < in_shape[1]; ++j) { + size_t in_offset_h = in_offset_n + j * in_strides[1]; + size_t out_offset_h = out_offset_n + j * block_size * out_strides[1]; + for (int k = 0; k < in_shape[2]; ++k) { + size_t in_offset_w = in_offset_h + k * in_strides[2]; + size_t out_offset_w = out_offset_h + k * block_size * out_strides[2]; + for (int l = 0; l < block_size; ++l) { + memcpy(output + out_offset_w + l * out_strides[1], input + in_offset_w + l * block_size * out_strides[2], + block_size * out_strides[2] * 4); + } + } + } + } + free(out_strides); + free(in_strides); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.h new file mode 100644 index 00000000000..1239bf16d0a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/depth_to_space.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct DepthToSpaceParameter { + OpParameter op_parameter_; + int32_t block_size_; +}; + +void DepthToSpaceForNHWC(const float *input, float *output, int *in_shape, int *out_shape, int shape_size, + int block_size); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_DEPTH_TO_SPACE_H_ + + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.cc new file mode 100644 index 00000000000..a9e78439959 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/expandDims.h" +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int ExpandDims(float *input_ptr, float *output_ptr, size_t data_size) { + memcpy(output_ptr, input_ptr, data_size); + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.h new file mode 100644 index 00000000000..9d440f2542b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/expandDims.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_EXPANDDIMS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_EXPANDDIMS_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ExpandDimsParameter { + OpParameter op_parameter_; + int dim_; +}; + +int ExpandDims(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_EXPANDDIMS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.cc new file mode 100644 index 00000000000..b898dcfe81b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/fill.h" + +int Fill(float *output, int size, float data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.h new file mode 100644 index 00000000000..2b53be39a59 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/fill.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FILL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FILL_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +#define FILL_DIMS_MAX_SIZE 4 + +struct FillParameter { + OpParameter op_parameter_; + int dims_[FILL_DIMS_MAX_SIZE]; + int num_dims_; +}; + +int Fill(float *output, int size, float data); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FILL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.cc new file mode 100644 index 00000000000..520b294bd25 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/gather.h" +#include + +inline int Stride(int *shape, int rank, int index) { + int i, stride = 1; + for (i = index + 1; i < rank; ++i) { + stride *= shape[i]; + } + return stride; +} + +int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size, + float *output) { + int i, m; + for (m = 0; m < outer_size; ++m) { + auto inputm = input + inner_size * m * limit; + auto outputm = output + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] > limit) { + return -1; + } + memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size); + } + } + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.h new file mode 100644 index 00000000000..9641407384c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gather.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHER_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct GatherParameter { + OpParameter op_parameter_; + int axis_; + int batchDims_; +}; + +int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size, + float *output); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.cc new file mode 100644 index 00000000000..00e715751ca --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/gatherNd.h" +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int GatherNd(float *input, float *output, int *in_offset, int area, int count) { + int i = 0; + for (i = 0; i < count; i++) { + (void)memcpy(output + area * i, input + in_offset[i], area * sizeof(float)); + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.h new file mode 100644 index 00000000000..30d92ede398 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/gatherNd.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHERND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHERND_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct GatherNdParameter { + OpParameter op_parameter_; + int batchDims_; +}; + +int GatherNd(float *input, float *output, int *in_offset, int area, int count); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_GATHERND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc new file mode 100644 index 00000000000..460f634b819 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" + +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, + float alpha, float beta) { + int i, j, k; + int left, right; + + for (i = 0; i < out_size; i++) { + float *in_data = input_ptr + i * channel; + float *out_data = output_ptr + i * channel; + + for (j = 0; j < channel; j++) { + left = MSMAX(0, j - depth_radius); + right = MSMIN(channel - 1, j + depth_radius); + + float sum = 0.0; + for (k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(pow((double)(sum * alpha + bias), -beta)); + } + } + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h new file mode 100644 index 00000000000..f3a989343d2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/local_response_norm.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct LocalResponseNormParameter { + OpParameter op_parameter_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +}; + +int LocalResponseNorm(float *input_ptr, int out_size, int channel, float *output_ptr, int depth_radius, float bias, + float alpha, float beta); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_LOCAL_RESPONSE_NORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc new file mode 100644 index 00000000000..5e2ebe7bc2f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/matmul.h" + +void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + float *src = src_ptr + r * col; + for (int c = 0; c < col; c++) { + int cd8 = c / 8; + int cm8 = c % 8; + dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; + } + } + return; +} + +void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + int rd8 = r / 8; + int rm8 = r % 8; + for (int c = 0; c < col; c++) { + dst_ptr[rd8 * col * 8 + c * 8 + rm8] = src_ptr[r * col + c]; + } + } + return; +} + +void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col) { + int row8 = UP_ROUND(row, 8); + for (int c = 0; c < col; c++) { + int cd8 = c / 8; + int cm8 = c % 8; + for (int r = 0; r < row; r++) { + dst_ptr[r * col + c] = src_ptr[cd8 * row8 * 8 + r * 8 + cm8]; + } + } +} + +void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, + int row_8_, int col_8_) { + /* col8-major * row8-major => col8x8-major */ + for (int row = 0; row < row_8_; row++) { + for (int col = 0; col < col_8_; col++) { + int r8div = row / 8, r8mod = row % 8; + int c8div = col / 8, c8mod = col % 8; + size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + value += bias[col]; + value = MSMIN(maxf, value); + value = MSMAX(minf, value); + c[ci] = value; + } + } +} + +void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep, int row_8_, + int col_8_) { + MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h new file mode 100644 index 00000000000..97c6db417ee --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_ + +#include "src/runtime/kernel/arm/opclib/errorcode.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/matmul.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, + int col); +void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col); +void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.cc new file mode 100644 index 00000000000..bef74d51d72 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/one_hot.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid, + const int thread_num) { + if (indices == nullptr || one_hot_param == nullptr || output == nullptr) { + return OPCLIB_NULL_PTR; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + float on_value = one_hot_param->on_value_; + float off_value = one_hot_param->off_value_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = indices[i * inner_size + j]; + if (index >= depth) { + return OPCLIB_ERRCODE_INDEX_OUT_OF_RANGE; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return OPCLIB_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.h new file mode 100644 index 00000000000..48bc0082d51 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/one_hot.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ONE_HOT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ONE_HOT_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct OneHotParameter { + OpParameter op_parameter_; + int axis_; + int depth_; + float on_value_; + float off_value_; + int outer_size_; + int inner_size_; +}; + +int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid, + const int thread_num); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_ONE_HOT_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc new file mode 100644 index 00000000000..084e7d42d51 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.cc @@ -0,0 +1,210 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" +#include + +void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int c4 = UP_DIV(channel, C4NUM); + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + // input channel is equal to output channel + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_avg = vdupq_n_f32(0); +#else + float tmp_avg1 = 0; + float tmp_avg2 = 0; + float tmp_avg3 = 0; + float tmp_avg4 = 0; +#endif + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset)); +#else + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); +#endif + ++real_count; + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + float32x4_t dup_count = vdupq_n_f32(real_count); + vst1q_f32(output_ptr + out_channel_offset, vdivq_f32(tmp_avg, dup_count)); +#else + *(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count; + *(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count; + *(output_ptr + out_channel_offset + 2) = tmp_avg3 / (float)real_count; + *(output_ptr + out_channel_offset + 3) = tmp_avg4 / (float)real_count; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_avg / (float)real_count; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + +void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c4 = UP_DIV(channel, C4NUM); + // input channel is equal to output channel + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_max = vdupq_n_f32(FLT_MIN); +#else + float tmp_max1 = FLT_MIN; + float tmp_max2 = FLT_MIN; + float tmp_max3 = FLT_MIN; + float tmp_max4 = FLT_MIN; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset)); +#else + tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + vst1q_f32(output_ptr + out_channel_offset, tmp_max); +#else + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_max = FLT_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.h new file mode 100644 index 00000000000..d615d49f366 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/pooling.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct PoolingParameter { + OpParameter op_parameter_; + QuantArg **quant_args_; + bool global_; + bool max_pooling_; + bool avg_pooling_; + bool round_ceil_; + bool round_floor_; + int window_w_; + int window_h_; + int input_w_; + int input_h_; + int input_batch_; + int input_channel_; + int output_w_; + int output_h_; + int output_batch_; + int output_channel_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int stride_w_; + int stride_h_; + int thread_num_; +}; + +void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_POOLING_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.cc new file mode 100644 index 00000000000..42523ee555b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/range.h" + +void Range(float *output_ptr, int start, int limit, int delta) { + size_t index = 0; + for (size_t i = start; i < limit; i += delta) { + output_ptr[index++] = (float)(i); + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.h new file mode 100644 index 00000000000..0f48a1e2fc8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/range.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANGE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANGE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct RangeParameter { + OpParameter op_parameter_; + int dType_; + int start_; + int limit_; + int delta_; +}; + +void Range(float *output_ptr, int start, int limit, int delta); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANGE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.cc new file mode 100644 index 00000000000..ef7f4062357 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/rank.h" + +void Rank(float* output, int rank) { + output[0] = (float)(rank); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.h new file mode 100644 index 00000000000..6933016063a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/rank.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANK_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Rank(float* output, int rank); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RANK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.cc new file mode 100644 index 00000000000..00886102bea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/opclib/fp32/reduce.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = tmp / (float)axis_size; + } + } + return OPCLIB_OK; +} +int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return OPCLIB_OK; +} +int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = -FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return OPCLIB_OK; +} +int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return OPCLIB_OK; +} +int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 1.0f; + for (i = 0; i < axis_size; i++) { + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return OPCLIB_OK; +} +int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num) { + if (src_data == nullptr || src_shape == nullptr || dst_data == nullptr) { + return OPCLIB_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float *outer_src = src_data + j * axis_size * inner_size; + float *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float *inner_src = outer_src + k; + float *inner_dst = outer_dst + k; + float tmp = 0.0f; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size] * inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return OPCLIB_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.h new file mode 100644 index 00000000000..0172f43e709 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reduce.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" +#define REDUCE_MAX_AXES_NUM 8 + +struct ReduceParameter { + OpParameter op_parameter_; + bool keep_dims_; + int axes_[REDUCE_MAX_AXES_NUM]; + int num_axes_; + int mode_; +}; + +int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceSum(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceMax(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +int ReduceSumSquare(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + const int *src_shape, float *dst_data, const int tid, const int thread_num); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_REDUCE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.cc new file mode 100644 index 00000000000..5ba3b90d0fd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/reverse.h" +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int Reverse(const float *input, float *output, size_t elem_size, int *index) { + for (int i = 0; i < elem_size; i++) { + output[index[i]] = input[i]; + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.h new file mode 100644 index 00000000000..d84bedbcf8b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/reverse.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#define REVERSE_SHAPE_MAX_SIZE 4 + +// For reverse. +struct ReverseParameter { + OpParameter op_parameter_; + int axis_[REVERSE_SHAPE_MAX_SIZE]; + int num_axis_; +}; + +int Reverse(const float *input, float *output, size_t elem_size, int *index); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.cc new file mode 100644 index 00000000000..4161b648c9c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/slice.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void PadSliceParameterTo4D(SliceParameter *param) { + int32_t begin[DIMENSION_4D]; + int32_t end[DIMENSION_4D]; + int32_t slice_size[DIMENSION_4D]; + int32_t data_shape[DIMENSION_4D]; + for (int32_t i = 0; i < param->param_length_; ++i) { + begin[i] = param->begin_[i]; + end[i] = param->end_[i]; + slice_size[i] = param->size_[i]; + data_shape[i] = param->shape_[i]; + } + int32_t real_index = param->param_length_ - 1; + for (int32_t i = DIMENSION_4D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begin_[i] = begin[real_index]; + param->end_[i] = end[real_index]; + param->size_[i] = slice_size[real_index]; + param->shape_[i] = data_shape[real_index--]; + } else { + param->begin_[i] = 0; + param->end_[i] = 1; + param->size_[i] = 1; + param->shape_[i] = 1; + } + } + param->param_length_ = DIMENSION_4D; +} + +int DoSlice(const float *input, SliceParameter *param, float *output) { + if (param->param_length_ > DIMENSION_4D) { + return -1; + } + + for (int i = 0; i < param->param_length_; ++i) { + if (param->size_[i] < 0) { + param->size_[i] = param->shape_[i] - param->begin_[i]; + } + param->end_[i] = param->begin_[i] + param->size_[i]; + } + + if (param->param_length_ < DIMENSION_4D) { + PadSliceParameterTo4D(param); + } + size_t dim_offset[DIMENSION_4D - 1]; + dim_offset[2] = param->shape_[3]; + dim_offset[1] = dim_offset[2] * param->shape_[2]; + dim_offset[0] = dim_offset[1] * param->shape_[1]; + size_t output_index = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + for (int32_t dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) { + for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) { + for (int32_t dim3 = param->begin_[3]; dim3 < param->end_[3]; ++dim3) { + output[output_index++] = *(input + dim0 * dim_offset[0] + + dim1 * dim_offset[1] + dim2 * dim_offset[2] + dim3); + } + } + } + } + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.h new file mode 100644 index 00000000000..8873101fcb2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/slice.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SLICE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SLICE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +#define SLICE_SHAPE_MAX_SIZE 4 + +struct SliceParameter { + OpParameter op_parameter_; + int32_t begin_[SLICE_SHAPE_MAX_SIZE]; + int32_t end_[SLICE_SHAPE_MAX_SIZE]; + int32_t size_[SLICE_SHAPE_MAX_SIZE]; + int32_t shape_[SLICE_SHAPE_MAX_SIZE]; + int32_t param_length_; +}; + +int DoSlice(const float *input, SliceParameter *param, float *output); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SLICE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.cc new file mode 100644 index 00000000000..75f0b1b165c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include + +// output = exp(input) / reduce_sum(exp(input), axis) +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) { + int32_t axis = parameter->axis_; + int n_dim = parameter->n_dim_; + int ele_size = parameter->element_size_; + int *input_shape = parameter->input_shape_; + + for (int i = 0; i < ele_size; i++) { + output_ptr[i] = exp(input_ptr[i]); + } + int inner_size = 1, outter_size = 1; + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + sum_data[j] += output_ptr[inner_offset]; + } + } + } + + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[j]; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h new file mode 100644 index 00000000000..2f91d15836d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct SoftmaxParameter { + OpParameter op_parameter; + int32_t axis_; + int element_size_; + int n_dim_; + int input_shape_[4]; +}; + +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter); + + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_SOFTMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.cc new file mode 100644 index 00000000000..10c0e6a6b97 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/stack.h" +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +void DoStack(const float * const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { + size_t one_input_size = 1; + for (size_t i = 0; i < shape_size; ++i) { + one_input_size *= in_shape[i]; + } + int in_strides[shape_size]; + ComputeStrides(in_shape, in_strides, shape_size); + + size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size; + size_t copy_size = copy_num * sizeof(float); + size_t pre_axis_count = 1; + for (size_t i = 0; i < axis; ++i) { + pre_axis_count *= in_shape[i]; + } + size_t in_offset = 0; + size_t out_offset = 0; + for (size_t i = 0; i < pre_axis_count; ++i) { + for (size_t j = 0; j < input_num; ++j) { + memcpy(output + out_offset, inputs[j] + in_offset, copy_size); + out_offset += copy_num; + } + in_offset += copy_num; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.h new file mode 100644 index 00000000000..18bd46a8d8e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/stack.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STACK_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct StackParameter { + OpParameter op_parameter_; + int32_t axis_; +}; + +void DoStack(const float * const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.cc new file mode 100644 index 00000000000..f1c31ba85dc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h" + +bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) { + if (cur_recursion >= max_recursion) { + return false; + } + + if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) { + return false; + } + + int row2 = row / 2; + int col2 = col / 2; + int deep2 = deep / 2; + + float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 - + 7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) - + 4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3); + + return (save_cost > 0.f); +} + +void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride, + int c_stride) { + int row4mod = row % 4; + int row4div = row / 4; + for (int r = 0; r < row; r++) { + int r4mod = r % 4; + int r4div = r / 4; + for (int c = 0; c < col * 4; c++) { + float value = 0; + int ic = c / 4 * c_stride + r * 4 + c % 4; + for (int d = 0; d < deep * 4; d++) { + int d4mod = d % 4; + int d4div = d / 4; + int a_stride = (r < (row4div * 4)) ? 4 : row4mod; + int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod; + int bi = c / 4 * b_stride + d * 4 + c % 4; + value = value + a_ptr[ai] * b_ptr[bi]; + } + dst_ptr[ic] = value; + } + } + return; +} + +void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride, + int c_stride) { + int row4mod = row % 4; + int row4div = row / 4; + + if (row4div > 0) { + GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride); + } + + if (row4mod != 0) { + GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride, + c_stride); + } + return; +} + +int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr) { + size_t row2 = matmul_param->row_ / 2; + size_t deep2 = matmul_param->deep_ / 2; + size_t col2 = matmul_param->col_ / 2; + size_t a_stride = matmul_param->a_stride_; + size_t b_stride = matmul_param->b_stride_; + size_t c_stride = matmul_param->c_stride_; + + StrassenMatMulParameter *rec_matmul = new StrassenMatMulParameter(); + rec_matmul->row_ = row2; + rec_matmul->deep_ = deep2; + rec_matmul->col_ = col2; + + float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float))); + if (x_ptr == nullptr) { + free(rec_matmul); + return OPCLIB_ERRCODE_STRASSEN_RECURSION_MALLOC; + } + float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float))); + if (y_ptr == nullptr) { + free(x_ptr); + free(rec_matmul); + return OPCLIB_ERRCODE_STRASSEN_RECURSION_MALLOC; + } + size_t x_stride = row2 * FP32_STRASSEN_UINT; + size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT; + + const float *a11 = a_ptr; + const float *a12 = a_ptr + deep2 * a_stride; + const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT; + const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT; + const float *b11 = b_ptr; + const float *b12 = b_ptr + col2 * b_stride; + const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT; + const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT; + float *c11 = c_ptr; + float *c12 = c_ptr + col2 * c_stride; + float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT; + float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT; + + /* S3 = A11 - A21 */ + MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2); + + /* T3 = B22 - B12 */ + MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P7 = S3T3 */ + rec_matmul->a_stride_ = x_stride; + rec_matmul->b_stride_ = y_stride; + rec_matmul->c_stride_ = c_stride; + StrassenMatmul(x_ptr, y_ptr, c21, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S1 = A21 + A22 */ + MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2); + + /* T1 = B12 - B11 */ + MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P5 = S1T1 */ + StrassenMatmul(x_ptr, y_ptr, c22, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S2 = S1 - A11 */ + MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2); + + /* T2 = B22 - T1 */ + MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2); + + /* P6 = S2T2 */ + StrassenMatmul(x_ptr, y_ptr, c12, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* S4 = A12 - S2 */ + MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2); + + /* P3 = S4B22 */ + rec_matmul->b_stride_ = b_stride; + StrassenMatmul(x_ptr, b22, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* P1 = A11B11 */ + rec_matmul->a_stride_ = a_stride; + rec_matmul->c_stride_ = row2 * FP32_STRASSEN_UINT; + StrassenMatmul(a11, b11, x_ptr, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U2 = P1 + P6 + U3 = U2 + P7 + U4 = U2 + P5 + U7 = U3 + P5 + U5 = U4 + P3 */ + MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride); + + /* T4 = T2 - B21 */ + MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2); + + /* P4 = A22T4 */ + rec_matmul->b_stride_ = y_stride; + rec_matmul->c_stride_ = c_stride; + StrassenMatmul(a22, y_ptr, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U6 = U3 - P4 */ + MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2); + + /* P2 = A12B21 */ + rec_matmul->b_stride_ = b_stride; + StrassenMatmul(a12, b21, c11, rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr); + + /* U1 = P1 + P2 */ + MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2); + + free(x_ptr); + free(y_ptr); + free(rec_matmul); + return OPCLIB_OK; +} + +int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + float *tmp_a_ptr) { + MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_); + GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_, + matmul_param->b_stride_, matmul_param->c_stride_); + return OPCLIB_OK; +} + +int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr) { + if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) { + return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr); + } + return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h new file mode 100644 index 00000000000..8e8cd606087 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strassen_matmul.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STRASSEN_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STRASSEN_MATMUL_H_ + +#include +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" +#include "src/runtime/kernel/arm/opclib/strassen_matmul.h" +#include "src/runtime/kernel/arm/opclib/fp32/common_func.h" + +#define FP32_STRASSEN_UINT C4NUM +#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM) +#define FP32_STRASSEN_MAX_RECURSION 5 + +int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int, float *tmp_a_ptr); +int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param, + float *tmp_a_ptr); + +int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param, + int max_recursion, int cur_recursion, float *tmp_a_ptr); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_STRASSEN_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.cc new file mode 100644 index 00000000000..7c2cb8e0f18 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/strided_slice.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +void PadStridedSliceParameterTo4D(StridedSliceParameter *param) { + int32_t begins[DIMENSION_4D]; + int32_t ends[DIMENSION_4D]; + int32_t strides[DIMENSION_4D]; + int32_t input_shape[DIMENSION_4D]; + for (int32_t i = 0; i < param->num_axes_; ++i) { + begins[i] = param->begins_[i]; + ends[i] = param->ends_[i]; + strides[i] = param->strides_[i]; + input_shape[i] = param->in_shape_[i]; + } + int32_t real_index = param->num_axes_ - 1; + for (int32_t i = DIMENSION_4D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begins_[i] = begins[real_index]; + param->ends_[i] = ends[real_index]; + param->strides_[i] = strides[real_index]; + param->in_shape_[i] = input_shape[real_index--]; + } else { + param->begins_[i] = 0; + param->ends_[i] = 1; + param->strides_[i] = 1; + param->in_shape_[i] = 1; + } + } + param->num_axes_ = DIMENSION_4D; +} + +int DoStridedSlice(const float *in_data, float *out_data, StridedSliceParameter *param) { + if (in_data == nullptr || out_data == nullptr || param == nullptr) { + return OPCLIB_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_4D) { + return OPCLIB_PARAM_INVALID; + } + + int *begins = param->begins_; + int *ends = param->ends_; + int *strides = param->strides_; + int *in_shape = param->in_shape_; + + if (param->num_axes_ < DIMENSION_4D) { + PadStridedSliceParameterTo4D(param); + } + + size_t dim_offset[DIMENSION_4D - 1]; + dim_offset[2] = in_shape[3]; + dim_offset[1] = dim_offset[2] * in_shape[2]; + dim_offset[0] = dim_offset[1] * in_shape[1]; + size_t output_index = 0; + for (int32_t dim0 = begins[0]; dim0 < ends[0]; dim0 += strides[0]) { + for (int32_t dim1 = begins[1]; dim1 < ends[1]; dim1 += strides[1]) { + for (int32_t dim2 = begins[2]; dim2 < ends[2]; dim2 += strides[2]) { + for (int32_t dim3 = begins[3]; dim3 < ends[3]; dim3 += strides[3]) { + out_data[output_index++] = + *(in_data + dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + dim3); + } + } + } + } + + return OPCLIB_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.h new file mode 100644 index 00000000000..1339ca391d7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/strided_slice.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_FP32_STRIDED_SLICE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_FP32_STRIDED_SLICE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct StridedSliceParameter { + OpParameter op_parameter_; + int begins_[8] = {0}; + int ends_[8] = {0}; + int strides_[8] = {1}; + int isScale; + int num_axes_; + int in_shape_[8]; +}; + +int DoStridedSlice(const float *inputs, float *output, StridedSliceParameter * param); +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_FP32_STRIDED_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.cc new file mode 100644 index 00000000000..f90ab24ed72 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/unsqueeze.h" +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size) { + memcpy(output_ptr, input_ptr, data_size); + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.h new file mode 100644 index 00000000000..1efadcd9146 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/unsqueeze.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSQUEEZE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +#define UNSQUEEZE_DIMS_MAX_SIZE 4 + +struct UnsqueezeParameter { + OpParameter op_parameter_; + int dims_[UNSQUEEZE_DIMS_MAX_SIZE]; + int num_dim_; +}; + +int Unsqueeze(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSQUEEZE_H_ + + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.cc new file mode 100644 index 00000000000..2034cc138f4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fused_batchnorm.h" + +void FusedBatchNorm(const float *input_ptr, const float *scale_ptr, const float *offest_ptr, const float *mean_ptr, + const float *variance_ptr, int *input_shapes, float epsilon, float *output_ptr) { + int channel = input_shapes[3]; + int units = 1; + for (int i = 0; i < 3; i++) { + units *= input_shapes[i]; + } + for (int c = 0; c < input_shapes[3]; c++) { + auto variance_sqrt = sqrt(variance_ptr[c] + epsilon); + for (int u = 0; u < units; u++) { + output_ptr[u * channel + c] = + (input_ptr[u * channel + c] - mean_ptr[c]) / variance_sqrt * scale_ptr[c] + offest_ptr[c]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.h new file mode 100644 index 00000000000..f81780106e3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fused_batchnorm.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FUSED_BATCHNORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FUSED_BATCHNORM_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct FusedBatchNormParameter { + OpParameter op_parameter_; + float epsilon_; +}; + + +void FusedBatchNorm(const float *input_ptr, const float *scale_ptr, const float *offest_ptr, const float *mean_ptr, + const float *variance_ptr, int *input_shapes, float epsilon, float *output_ptr); + + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FUSED_BATCHNORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.cc new file mode 100644 index 00000000000..07627229aff --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/concat_int8.h" +#include + +void Concat(int8_t **inputs, int8_t *output_ptr, ConcatQuantArg *quant_concat_parm, int axis) { + float output_scale = quant_concat_parm->out_quant_args_.scale_; + float output_inverse_scale = 1.f / output_scale; + int input_num = quant_concat_parm->input_num_; + int *output_shape = quant_concat_parm->output_shape_; + int output_dim = quant_concat_parm->output_dim_; + QuantArg *input_quant = quant_concat_parm->in_quant_args_; + int output_zp = quant_concat_parm->out_quant_args_.zp_; + + int before_axis_size = 1; + for (int i = 0; i < axis; i++) { + before_axis_size *= output_shape[i]; + } + + int after_axis_size = 1; + for (size_t i = axis + 1; i < output_dim; i++) { + after_axis_size *= output_shape[i]; + } + + for (int k = 0; k < before_axis_size; k++) { + for (int i = 0; i < input_num; i++) { + int *input_shape = quant_concat_parm->input_shapes_[i]; + int copy_size = input_shape[axis] * after_axis_size; + int8_t *input_ptr = inputs[i] + k * copy_size; + if (input_quant[i].scale_ == output_scale && input_quant[i].zp_ == output_zp) { + memcpy(output_ptr, input_ptr, copy_size); + } else { + float scale = input_quant[i].scale_ * output_inverse_scale; + float bias = -input_quant[i].zp_ * scale; + for (int j = 0; j < copy_size; j++) { + int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } + } + output_ptr += copy_size; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.h new file mode 100644 index 00000000000..32e66c807e7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/concat_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONCAT_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONCAT_INT8_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Concat(int8_t **inputs, int8_t *output_ptr, ConcatQuantArg *quant_concat_parm, int axis); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONCAT_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.cc new file mode 100644 index 00000000000..02bba0ae388 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.cc @@ -0,0 +1,322 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h" +#include +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +/*conv depthwise int8 begin*/ +void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, int out_multiplier, + int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int tmp_buffer[C4NUM]; + for (int i = 0; i < C4NUM; i++) { + tmp_buffer[i] = 0; + } + const int16_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int16_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + dst[c] = static_cast(tmp_buffer[c]); + } +} + +void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + int8_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const int16_t *src_h = src + ih * sliding->in_h_step_; + + int8_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const int16_t *src_w = src_h + iw * sliding->block_channel_; + + const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + + DepthwiseBorderPixelInt8( + dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_, + sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0]); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step, int out_multiplier, int left_shift, + int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int tmp_buffer[C4NUM]; + int8_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int8_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const int16_t *src_kh = src_w; + const int16_t *weight_kh = weight; + + for (int i = 0; i < C4NUM; i++) { + tmp_buffer[i] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const int16_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add bias relu + for (int c = 0; c < C4NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), + -right_shift); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + dst_w[c] = static_cast(tmp_buffer[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const int16_t *src_data = src + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * C4NUM; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * C4NUM; + + DepthwiseCenterInt8( + out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nc4hwc4 +} +/*conv depthwise int8 end*/ + +/*deconv depthwise int8 begin*/ +void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + int32_t *dst_kh = dst; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDepthwiseBorderInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const int16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int32_t *dst_h = dst + oh * sliding->in_h_step_; + + const int16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + int32_t *dst_w = dst_h + ow * C4NUM; + + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + int32_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +void DeconvDepthwiseCenterInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + int32_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int32_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + int32_t *dst_kh = dst_w; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} + +void DeconvDepthwisePostFuncInt8(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, + const ConvParameter *conv_param, int out_multiplier, int left_shift, int right_shift, + int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int8_t *dst_k = dst; + int32_t *buffer_k = output_buffer; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + buffer_k[c] += bias[c]; + buffer_k[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer_k[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + buffer_k[c] += out_zp; + buffer_k[c] = MSMAX(buffer_k[c], acc_min); + buffer_k[c] = MSMIN(buffer_k[c], acc_max); + dst_k[c] = static_cast(buffer_k[c]); + } + dst_k += block_channel; + buffer_k += C4NUM; + } +} + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + int buffer_size = conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + memset(output_buffer, 0, buffer_size * sizeof(int32_t)); + const int16_t *src_data = src + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderInt8(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * C4NUM; + const int16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + + DeconvDepthwiseCenterInt8(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); + } + DeconvDepthwisePostFuncInt8( + dst_data, output_buffer, bias, sliding->block_channel_, conv_param, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nc4hwc4 +} +/*deconv depthwise int8 end*/ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h new file mode 100644 index 00000000000..e88c6b85295 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_depthwise_int8.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_DEPTHWISE_H_ + +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv_depthwise.h" + +void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_DEPTHWISE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc new file mode 100644 index 00000000000..e3573498d0c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc @@ -0,0 +1,338 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/conv_int8.h" +#include +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" + +extern "C" { +#ifdef ENABLE_ARM +void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, + size_t oc4, size_t offset); + +#ifdef ENABLE_ARM64 +// void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t +// ksize, +// size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, +// size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, size_t +// shift_before, size_t shift_after); +void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); +#elif defined(ENABLE_ARM32) +void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); +#endif +#endif +} + +void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param) { + int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; + int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; + int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; +#ifdef __aarch64__ + IndirectGemmInt8_4x4(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), + input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); +#else + int tile_num = 4; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4_block = oc / C4NUM; + int oc4_res = oc % C4NUM; + int weight_oc4_offset = oc4_block * C4NUM * plane_c4 * C4NUM * ic4 * C4NUM + oc4_res * C4NUM * C4NUM; + int dst_oc_offset = oc; + for (int n = 0; n < tile_num; n++) { + int src_tile_offset = n * C4NUM * C4NUM; + int dst_tile_offset = dst_oc_offset + n * output_channel; + + for (int b = 0; b < kernel_plane; b++) { + int plane_c4_block = b / C4NUM; + int plane_c4_res = b % C4NUM; + int src_plane_offset = src_tile_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM; + int weight_plane_offset = + weight_oc4_offset + plane_c4_block * tile_num * C4NUM * ic4 * C4NUM + plane_c4_res * C4NUM; + for (int i = 0; i < ic4; i++) { + int src_ic4_offset = src_plane_offset + i * tile_num * C4NUM * C4NUM; + int weight_ic4_offset = weight_plane_offset + i * C4NUM * C4NUM * C4NUM; + for (int j = 0; j < C4NUM; j++) { + int weight_ic_offset = weight_ic4_offset + j; + tmp_dst[dst_tile_offset] += weight[weight_ic_offset] * src[src_ic4_offset + j]; + } // in c4num loop + } // ic4 loop + } // kernel_plane loop + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } // tile_num loop + } // output_channel loop +#endif +} + +void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param, GEMM_FUNC gemm_func) { + int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0]; + int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0]; + int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; + if (gemm_func != nullptr) { +#ifdef __aarch64__ + gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum, + act_min, act_max, out_zp, out_multiplier, shift_before, shift_after); +#endif + } else { + int tile_num = 24; + for (int oc = 0; oc < output_channel; oc++) { + int oc4_block = oc / C4NUM; + int oc4_res = oc % C4NUM; + int weight_oc4_offset = oc4_block * C4NUM * kernel_plane * ic4 * C4NUM + oc4_res * C4NUM; + int dst_oc_offset = oc; + for (int n = 0; n < tile_num; n++) { + int src_tile_offset = n * C4NUM; + int dst_tile_offset = dst_oc_offset + n * output_channel; + + for (int b = 0; b < kernel_plane; b++) { + int src_plane_offset = src_tile_offset + b * tile_num * ic4 * C4NUM; + int weight_plane_offset = weight_oc4_offset + b * C4NUM * ic4 * C4NUM; + for (int i = 0; i < ic4; i++) { + int src_ic4_offset = src_plane_offset + i * tile_num * C4NUM; + int weight_ic4_offset = weight_plane_offset + i * C4NUM * C4NUM; + for (int j = 0; j < C4NUM; j++) { + int weight_ic_offset = weight_ic4_offset + j; + tmp_dst[dst_tile_offset] += weight[weight_ic_offset] * src[src_ic4_offset + j]; + } // in c4num loop + } // ic4 loop + } // kernel_plane loop + tmp_dst[dst_tile_offset] -= input_sum[n]; + int result = tmp_dst[dst_tile_offset] + bias[oc]; + result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after); + result += out_zp; + result = result > act_min ? result : act_min; + result = result < act_max ? result : act_max; + dst[dst_tile_offset] = (int8_t)result; + } // tile_num loop + } // output_channel loop + } +} + +void Conv3x3Uint8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { + int oc4 = UP_DIV(oc, C4NUM); + int input_unit_square = 16; +#ifdef ENABLE_ARM + IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); +#else + for (int c = 0; c < oc4; c++) { + int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; + int dst_oc_offset = c * input_unit_square * C4NUM; + for (int n = 0; n < real_cal_num; n++) { + int src_tile_offset = n * C8NUM; + int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; + for (int i = 0; i < 4; i++) { + int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; + int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; + int dst_h_offset = dst_tile_offset + i * 4 * 4; + for (int m = 0; m < 4; m++) { + int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; + int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; + int dst_w_offset = dst_h_offset + m * C4NUM; + + int32_t acc[4] = {0}; + for (int z = 0; z < 4; z++) { + int filter_offset = filter_w_offset + z; + for (int j = 0; j < ic8; j++) { + int filter_c8_offset = filter_offset + j * 4 * 8; + int src_c8_offset = src_w_offset + j * 8 * 8; + + for (int k = 0; k < 8; k++) { + const int16_t *w_ptr = weight + filter_c8_offset + k * 4; + const int16_t *input_ptr = src + src_c8_offset + k; + acc[z] += w_ptr[0] * input_ptr[0]; + } + } + (dst + dst_w_offset + z)[0] = acc[z]; + } + } + } + } + } +#endif +} + +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + + int tile_n = 4; + int thread_count = conv_param->thread_num_; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + int32_t *tmp_input_sum = input_sum + thread_id * tile_n; + int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + // clear tmp buffer before compute + memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + // todo + size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t); + memset(tmp_dst, 0, tmp_dst_size); + + Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); + if (real_cal_num == tile_n) { + int8_t *gemm_output = output_data + out_offset; + IndirectGemmInt8(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, + input_sum, conv_param); + } else { + // res part + IndirectGemmInt8(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, + input_sum, conv_param); + memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); + } + } + } +} + +void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param, GEMM_FUNC gemm_func) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + + // todo + int tile_n = 24; + int thread_count = conv_param->thread_num_; + int output_count = out_h * out_w; + int output_tile_count = UP_DIV(output_count, tile_n); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int unit_size = kernel_plane * ic4 * C4NUM; + int packed_input_size = output_tile_count * tile_n * unit_size; + + for (int b = 0; b < in_batch; b++) { + int in_batch_offset = b * in_channel * in_h * in_w; + int out_batch_offset = b * out_channel * out_h * out_w; + int gemm_in_batch_offset = b * packed_input_size; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + // todo + int32_t *tmp_input_sum = input_sum + thread_id * tile_n; + int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset; + // clear tmp buffer before compute + memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + // todo + size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t); + memset(tmp_dst, 0, tmp_dst_size); + + Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); + if (real_cal_num == tile_n) { + int8_t *gemm_output = output_data + out_offset; + IndirectGemmInt8Opt(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, + input_sum, conv_param, gemm_func); + } else { + // res part + IndirectGemmInt8Opt(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, + input_sum, conv_param, gemm_func); + memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); + } + } + } +} + +// int8 convolution 3x3 +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param) { + // todo + int thread_count = conv_param->thread_num_; + int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); + int output_batch = conv_param->output_batch_; + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + int output_tile_count = UP_DIV(output_count, TILE_NUM); + + int input_batch = conv_param->input_batch_; + for (int batch = 0; batch < input_batch; batch++) { + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + + Conv3x3Uint8InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, + conv_param); + + Conv3x3Uint8Gemm(tmp_dst_buffer, tile_buffer, transed_weight, output_channel, ic8, real_cal_num); + + Conv3x3Uint8OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block, + conv_param); + } + } + + // get real output + for (int batch = 0; batch < output_batch; batch++) { + // int batch_size = batch * output_channel * output_h * output_w; + C4UnpackToHwcInt8(tmp_out, output_data, output_channel, output_h, output_w); + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.h new file mode 100644 index 00000000000..bbf8424f1d4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/common_func.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/winograd_utils.h" +#include "src/runtime/kernel/arm/opclib/quantization/quantize.h" + +typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before, + size_t shift_after); + +void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param); + +void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum, + ConvParameter *conv_param, GEMM_FUNC gemm_func); + +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param); + +void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, const int32_t *bias_data, + int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id, + ConvParameter *conv_param, GEMM_FUNC gemm_func); + +// int8 convolution 3x3 +void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, + int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, + int task_id, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_CONV_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.cc new file mode 100644 index 00000000000..52587034ee9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/deconv.h" + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, + ConvParameter *conv_param) { + MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.quant_args_[0][0].zp_, + conv_param->conv_quant_arg_.quant_args_[1][0].zp_); + return OPCLIB_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param) { + int oc8 = UP_DIV(output_channel, C8NUM); + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + + for (int c = 0; c < oc8; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C8NUM; + const int32_t *src_ptr = src + c * input_plane * kernel_plane * C8NUM; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * conv_param->input_w_ * C8NUM + iw * C8NUM + + kh * input_plane * conv_param->kernel_w_ * C8NUM + kw * input_plane * C8NUM; + int dst_index = oh * conv_param->output_w_ * C8NUM + ow * C8NUM + + kh * conv_param->dilation_h_ * conv_param->output_w_ * C8NUM + + kw * conv_param->dilation_w_ * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_ptr[dst_index + i] += src_ptr[src_index + i]; + } + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8), + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return OPCLIB_OK; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.h new file mode 100644 index 00000000000..ccda22571ea --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/deconv.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DECONV_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DECONV_H_ + +#include +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/common_func.h" +#include "src/runtime/kernel/arm/opclib/int8/matmul.h" + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, + ConvParameter *conv_param); + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_DECONV_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc new file mode 100644 index 00000000000..aa5de959de0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/matmul.h" +#include +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + int rd8 = r / 8; + int rm8 = r % 8; + for (int c = 0; c < col; c++) { + dst_ptr[rd8 * col * 8 + c * 8 + rm8] = src_ptr[r * col + c]; + } + } + return; +} + +void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, + const int32_t a_zp, const int32_t b_zp) { + /* col8-major * row8-major => row8x8-major */ + for (int row = 0; row < row8; row++) { + for (int col = 0; col < col8; col++) { + int r8div = row / 8, r8mod = row % 8; + int c8div = col / 8, c8mod = col % 8; + size_t ci = c8div * row8 * 8 + row * 8 + c8mod; + int32_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r8div * deep * 8 + d * 8 + r8mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp); + } + c[ci] = value; + } + } + return; +} + +// todo: need to delete, replace by above functions. z00445833 +void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col8 = UP_ROUND(col, 8); + for (int r = 0; r < row; r++) { + int rd8 = r / 8; + int rm8 = r % 8; + for (int c = 0; c < col; c++) { + dst_ptr[r * col + c] = src_ptr[rd8 * col8 * 8 + c * 8 + rm8]; + } + } +} + +void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data, + int depth, FcQuantArg *params) { + int lhs_offset = params->input.zp_; + int rhs_offset = params->weight.zp_; + int output_offset = params->output.zp_; + int output_multiplier = params->quant_multiplier; + int output_shift = params->output_shift; + + for (int row = 0; row < 8; ++row) { + for (int col = 0; col < 8; ++col) { + int c_index = col * 8 + row; + int acc = 0; + for (int d = 0; d < depth; ++d) { + int a_index = d * 8 + row; + int b_index = d * 8 + col; + acc += (lhs_data[a_index] - lhs_offset) * (rhs_data[b_index] - rhs_offset); + } + acc += bias_data[col]; + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift, output_shift) + output_offset; + acc = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, acc)); + output_data[c_index] = (int8_t)acc; + } + } +} + +void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data, + int row_8, int col_8, int depth, FcQuantArg *params) { + for (int r = 0; r < row_8; r += 8) { + int8_t *output = output_data + r * col_8; + const int8_t *input = input_data + r * depth; + for (int c = 0; c < col_8; c += 8) { + const int8_t *bias = bias_data + c; + const int8_t *weights = weights_data + c * depth; + int8_t *dst = output + c * 8; + Gemm8x8Int8(input, weights, bias, dst, depth, params); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h new file mode 100644 index 00000000000..d0c85ba5a19 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/matmul.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, + const int32_t a_zp, const int32_t b_zp); +void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); + +void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data, + int depth, FcQuantArg *params); +void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data, + int row_8, int col_8, int depth, FcQuantArg *params); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.cc new file mode 100644 index 00000000000..8b493f51e9e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/mul_int8.h" +#include "src/runtime/kernel/arm/opclib/mul_parameter.h" +#ifdef ENABLE_NEON +#include +#include "src/runtime/kernel/arm/opclib/add_int8.h" +#endif +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, MulQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_quant_arg_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void MulInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + MulQuantArg para, int *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para.in_quant_args_[0].zp_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para.in_quant_args_[1].zp_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int16x4_t sum_low = + ClacSumHalfWord(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + ClacSumHalfWord(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + } +} +#endif + +void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para) { + int index = 0; +#ifdef ENABLE_NEON + MulInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para.in_quant_args_[0].zp_ + input0_data[index]; + const int32_t input1_val = para.in_quant_args_[1].zp_ + input1_data[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << para.shift_left_), + para.output_multiplier_), para.shift_right_); + + mul_result += para.out_quant_arg_.zp_; + + if (mul_result > para.output_activation_max_) { + output_data[index] = para.output_activation_max_; + } else if (mul_result < para.output_activation_min_) { + output_data[index] = para.output_activation_min_; + } else { + output_data[index] = static_cast(mul_result); + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.h new file mode 100644 index 00000000000..9e6a4aa4b80 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/mul_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MUL_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MUL_INT8_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/mul_parameter.h" + +void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MUL_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.cc new file mode 100644 index 00000000000..3f621211cc0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.cc @@ -0,0 +1,372 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/pooling_int8.h" +#include "src/runtime/kernel/arm/opclib/common_func.h" + +void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int16_t out_min = INT8_MIN; + int16_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c8 = UP_DIV(channel, C8NUM); + int8_t out_min = INT8_MIN; + int8_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c8 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C8NUM; + int out_channel_offset = out_plane_offset + j * C8NUM; + int16_t tmp_avg1 = 0; + int16_t tmp_avg2 = 0; + int16_t tmp_avg3 = 0; + int16_t tmp_avg4 = 0; + int16_t tmp_avg5 = 0; + int16_t tmp_avg6 = 0; + int16_t tmp_avg7 = 0; + int16_t tmp_avg8 = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); + tmp_avg5 += *(input_ptr + in_offset + 4); + tmp_avg6 += *(input_ptr + in_offset + 5); + tmp_avg7 += *(input_ptr + in_offset + 6); + tmp_avg8 += *(input_ptr + in_offset + 7); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count); + int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count); + int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count); + int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count); + int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count); + int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count); + int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count); + int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count); + int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1; + int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2; + int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3; + int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4; + int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5; + int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6; + int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7; + int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8; + real_out1 = real_out1 > out_max ? out_max : real_out1; + real_out2 = real_out2 > out_max ? out_max : real_out2; + real_out3 = real_out3 > out_max ? out_max : real_out3; + real_out4 = real_out4 > out_max ? out_max : real_out4; + real_out5 = real_out5 > out_max ? out_max : real_out5; + real_out6 = real_out6 > out_max ? out_max : real_out6; + real_out7 = real_out7 > out_max ? out_max : real_out7; + real_out8 = real_out8 > out_max ? out_max : real_out8; + *(output_ptr + out_channel_offset) = (int8_t)real_out1; + *(output_ptr + out_channel_offset + 1) = (int8_t)real_out2; + *(output_ptr + out_channel_offset + 2) = (int8_t)real_out3; + *(output_ptr + out_channel_offset + 3) = (int8_t)real_out4; + *(output_ptr + out_channel_offset + 4) = (int8_t)real_out5; + *(output_ptr + out_channel_offset + 5) = (int8_t)real_out6; + *(output_ptr + out_channel_offset + 6) = (int8_t)real_out7; + *(output_ptr + out_channel_offset + 7) = (int8_t)real_out8; + } // in_channel loop + int channel_s = (c8 - 1) * C8NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + // input channel is equal to output channel + float input_scale = pooling_param->quant_args_[0][0].scale_; + int input_zp = pooling_param->quant_args_[0][0].zp_; + float output_scale = pooling_param->quant_args_[1][0].scale_; + int output_zp = pooling_param->quant_args_[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c16 = UP_DIV(channel, 16); + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max1 = INT8_MIN; + int8_t tmp_max2 = INT8_MIN; + int8_t tmp_max3 = INT8_MIN; + int8_t tmp_max4 = INT8_MIN; + int8_t tmp_max5 = INT8_MIN; + int8_t tmp_max6 = INT8_MIN; + int8_t tmp_max7 = INT8_MIN; + int8_t tmp_max8 = INT8_MIN; + int8_t tmp_max9 = INT8_MIN; + int8_t tmp_max10 = INT8_MIN; + int8_t tmp_max11 = INT8_MIN; + int8_t tmp_max12 = INT8_MIN; + int8_t tmp_max13 = INT8_MIN; + int8_t tmp_max14 = INT8_MIN; + int8_t tmp_max15 = INT8_MIN; + int8_t tmp_max16 = INT8_MIN; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3)); + tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4)); + tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5)); + tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6)); + tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7)); + tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8)); + tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9)); + tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10)); + tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11)); + tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12)); + tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13)); + tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14)); + tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15)); +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; + *(output_ptr + out_channel_offset + 4) = tmp_max5; + *(output_ptr + out_channel_offset + 5) = tmp_max6; + *(output_ptr + out_channel_offset + 6) = tmp_max7; + *(output_ptr + out_channel_offset + 7) = tmp_max8; + *(output_ptr + out_channel_offset + 8) = tmp_max9; + *(output_ptr + out_channel_offset + 9) = tmp_max10; + *(output_ptr + out_channel_offset + 10) = tmp_max11; + *(output_ptr + out_channel_offset + 11) = tmp_max12; + *(output_ptr + out_channel_offset + 12) = tmp_max13; + *(output_ptr + out_channel_offset + 13) = tmp_max14; + *(output_ptr + out_channel_offset + 14) = tmp_max15; + *(output_ptr + out_channel_offset + 15) = tmp_max16; +#endif + } // in_channel loop + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.h new file mode 100644 index 00000000000..9636159e7f7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/pooling_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_POOLING_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" + +void AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_POOLING_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.cc new file mode 100644 index 00000000000..aabed8b3afc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/reshape_int8.h" +#include + +void Reshape(int8_t *input_ptr, int8_t *output_ptr, size_t data_size, int input_num, QuantArg in_quant_arg, + QuantArg out_quant_arg) { + if (in_quant_arg.scale_ == out_quant_arg.scale_ && in_quant_arg.zp_ == out_quant_arg.zp_) { + memcpy(output_ptr, input_ptr, data_size); + } else { + float output_inverse_scale = 1.f / out_quant_arg.scale_; + float scale = in_quant_arg.scale_ * output_inverse_scale; + float bias = -in_quant_arg.zp_ * scale; + int32_t output_zp = out_quant_arg.zp_; + for (int i = 0; i < input_num; i++) { + int32_t output_tmp = round(input_ptr[i] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[i] = 127; + } else if (output_tmp < -128) { + output_ptr[i] = -128; + } else { + output_ptr[i] = (int8_t)output_tmp; + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.h new file mode 100644 index 00000000000..8ba2e835376 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/reshape_int8.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_RESHAHPE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_RESHAHPE_INT8_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Reshape(int8_t *input_ptr, int8_t *output_ptr, size_t data_size, int input_num, QuantArg in_quant_arg, + QuantArg out_quant_arg); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_RESHAHPE_INT8_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/matmul.h new file mode 100644 index 00000000000..f83ae884ef6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/matmul.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATMUL_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct MatMulParameter { + OpParameter op_parameter_; + int row_; + int col_; + int row_8_; + int col_8_; + int deep_; + float minf_; + float maxf_; + bool has_bias_; + bool a_transpose_; /* false : row-major */ + bool b_transpose_; /* true : col-major */ +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/matrix_table.h b/mindspore/lite/src/runtime/kernel/arm/opclib/matrix_table.h new file mode 100644 index 00000000000..3b9316d6205 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/matrix_table.h @@ -0,0 +1,512 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATRIX_TABLE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATRIX_TABLE_H_ + +inline void MatrixG4x2(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = -0.5f; + matrix_data[6] = 0.0f; + matrix_data[7] = 1.0f; +} + +inline void MatrixGT2x4(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.5f; + matrix_data[6] = -0.5f; + matrix_data[7] = 1.0f; +} + +inline void MatrixG8x2(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 0.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = -0.5f; + matrix_data[6] = 1.0f; + matrix_data[7] = 1.0f; + matrix_data[8] = 1.0f; + matrix_data[9] = -1.0f; + matrix_data[10] = 1.0f; + matrix_data[11] = 1.5f; + matrix_data[12] = 1.0f; + matrix_data[13] = -1.5f; + matrix_data[14] = 0.0f; + matrix_data[15] = 1.0f; +} + +inline void MatrixGT2x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.5f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 1.0f; +} + +inline void MatrixG8x3(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 0.5f; + matrix_data[5] = 0.25f; + matrix_data[6] = 1.0f; + matrix_data[7] = -0.5f; + matrix_data[8] = 0.25f; + matrix_data[9] = 1.0f; + matrix_data[10] = 1.0f; + matrix_data[11] = 1.0f; + matrix_data[12] = 1.0f; + matrix_data[13] = -1.0f; + matrix_data[14] = 1.0f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.5f; + matrix_data[17] = 2.25f; + matrix_data[18] = 1.0f; + matrix_data[19] = -1.5f; + matrix_data[20] = 2.25f; + matrix_data[21] = 0.0f; + matrix_data[22] = 0.0f; + matrix_data[23] = 1.0f; +} + +inline void MatrixGT3x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 1.0f; +} + +inline void MatrixG8x4(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 0.5f; + matrix_data[6] = 0.25f; + matrix_data[7] = 0.125f; + matrix_data[8] = 1.0f; + matrix_data[9] = -0.5f; + matrix_data[10] = 0.25f; + matrix_data[11] = -0.125f; + matrix_data[12] = 1.0f; + matrix_data[13] = 1.0f; + matrix_data[14] = 1.0f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.0f; + matrix_data[17] = -1.0f; + matrix_data[18] = 1.0f; + matrix_data[19] = -1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 1.5f; + matrix_data[22] = 2.25f; + matrix_data[23] = 3.375f; + matrix_data[24] = 1.0f; + matrix_data[25] = -1.5f; + matrix_data[26] = 2.25f; + matrix_data[27] = -3.375f; + matrix_data[28] = 0.0f; + matrix_data[29] = 0.0f; + matrix_data[30] = 0.0f; + matrix_data[31] = 1.0f; +} + +inline void MatrixGT4x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 1.0f; +} + +inline void MatrixG8x5(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 0.5f; + matrix_data[7] = 0.25f; + matrix_data[8] = 0.125f; + matrix_data[9] = 0.0625f; + matrix_data[10] = 1.0f; + matrix_data[11] = -0.5f; + matrix_data[12] = 0.25f; + matrix_data[13] = -0.125f; + matrix_data[14] = 0.0625f; + matrix_data[15] = 1.0f; + matrix_data[16] = 1.0f; + matrix_data[17] = 1.0f; + matrix_data[18] = 1.0f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = -1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = -1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = 1.0f; + matrix_data[26] = 1.5f; + matrix_data[27] = 2.25f; + matrix_data[28] = 3.375f; + matrix_data[29] = 5.0625f; + matrix_data[30] = 1.0f; + matrix_data[31] = -1.5f; + matrix_data[32] = 2.25f; + matrix_data[33] = -3.375f; + matrix_data[34] = 5.0625f; + matrix_data[35] = 0.0f; + matrix_data[36] = 0.0f; + matrix_data[37] = 0.0f; + matrix_data[38] = 0.0f; + matrix_data[39] = 1.0f; +} + +inline void MatrixGT5x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 1.0f; +} + +inline void MatrixG8x6(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.5f; + matrix_data[8] = 0.25f; + matrix_data[9] = 0.125f; + matrix_data[10] = 0.0625f; + matrix_data[11] = 0.03125f; + matrix_data[12] = 1.0f; + matrix_data[13] = -0.5f; + matrix_data[14] = 0.25f; + matrix_data[15] = -0.125f; + matrix_data[16] = 0.0625f; + matrix_data[17] = -0.03125f; + matrix_data[18] = 1.0f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = 1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = -1.0f; + matrix_data[26] = 1.0f; + matrix_data[27] = -1.0f; + matrix_data[28] = 1.0f; + matrix_data[29] = -1.0f; + matrix_data[30] = 1.0f; + matrix_data[31] = 1.5f; + matrix_data[32] = 2.25f; + matrix_data[33] = 3.375f; + matrix_data[34] = 5.0625f; + matrix_data[35] = 7.59375f; + matrix_data[36] = 1.0f; + matrix_data[37] = -1.5f; + matrix_data[38] = 2.25f; + matrix_data[39] = -3.375f; + matrix_data[40] = 5.0625f; + matrix_data[41] = -7.59375f; + matrix_data[42] = 0.0f; + matrix_data[43] = 0.0f; + matrix_data[44] = 0.0f; + matrix_data[45] = 0.0f; + matrix_data[46] = 0.0f; + matrix_data[47] = 1.0f; +} + +inline void MatrixGT6x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 0.0f; + matrix_data[40] = 0.0; + matrix_data[41] = 0.03125f; + matrix_data[42] = -0.03125f; + matrix_data[43] = 1.0f; + matrix_data[44] = -1.0f; + matrix_data[45] = 7.59375f; + matrix_data[46] = -7.59375f; + matrix_data[47] = 0.0f; + matrix_data[48] = 1.0f; +} + +inline void MatrixG8x7(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 0.0f; + matrix_data[2] = 0.0f; + matrix_data[3] = 0.0f; + matrix_data[4] = 0.0f; + matrix_data[5] = 0.0f; + matrix_data[6] = 0.0f; + matrix_data[7] = 1.0f; + matrix_data[8] = 0.5f; + matrix_data[9] = 0.25f; + matrix_data[10] = 0.125f; + matrix_data[11] = 0.0625f; + matrix_data[12] = 0.03125f; + matrix_data[13] = 0.015625f; + matrix_data[14] = 1.0f; + matrix_data[15] = -0.5f; + matrix_data[16] = 0.25f; + matrix_data[17] = -0.125f; + matrix_data[18] = 0.0625f; + matrix_data[19] = -0.03125f; + matrix_data[20] = 0.015625f; + matrix_data[21] = 1.0f; + matrix_data[22] = 1.0f; + matrix_data[23] = 1.0f; + matrix_data[24] = 1.0f; + matrix_data[25] = 1.0f; + matrix_data[26] = 1.0f; + matrix_data[27] = 1.0f; + matrix_data[28] = 1.0f; + matrix_data[29] = -1.0f; + matrix_data[30] = 1.0f; + matrix_data[31] = -1.0f; + matrix_data[32] = 1.0f; + matrix_data[33] = -1.0f; + matrix_data[34] = 1.0f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.5f; + matrix_data[37] = 2.25f; + matrix_data[38] = 3.375f; + matrix_data[39] = 5.0625f; + matrix_data[40] = 7.59375f; + matrix_data[41] = 11.390625f; + matrix_data[42] = 1.0f; + matrix_data[43] = -1.5f; + matrix_data[44] = 2.25f; + matrix_data[45] = -3.375f; + matrix_data[46] = 5.0625f; + matrix_data[47] = -7.59375f; + matrix_data[48] = 11.390625f; + matrix_data[49] = 0.0f; + matrix_data[50] = 0.0f; + matrix_data[51] = 0.0f; + matrix_data[52] = 0.0f; + matrix_data[53] = 0.0f; + matrix_data[54] = 0.0f; + matrix_data[55] = 1.0f; +} + +inline void MatrixGT7x8(float *matrix_data) { + matrix_data[0] = 1.0f; + matrix_data[1] = 1.0f; + matrix_data[2] = 1.0f; + matrix_data[3] = 1.0f; + matrix_data[4] = 1.0f; + matrix_data[5] = 1.0f; + matrix_data[6] = 1.0f; + matrix_data[7] = 0.0f; + matrix_data[8] = 0.0f; + matrix_data[9] = 0.5f; + matrix_data[10] = -0.5f; + matrix_data[11] = 1.0f; + matrix_data[12] = -1.0f; + matrix_data[13] = 1.5f; + matrix_data[14] = -1.5f; + matrix_data[15] = 0.0f; + matrix_data[16] = 0.0f; + matrix_data[17] = 0.25f; + matrix_data[18] = 0.25f; + matrix_data[19] = 1.0f; + matrix_data[20] = 1.0f; + matrix_data[21] = 2.25f; + matrix_data[22] = 2.25f; + matrix_data[23] = 0.0f; + matrix_data[24] = 0.0f; + matrix_data[25] = 0.125f; + matrix_data[26] = -0.125f; + matrix_data[27] = 1.0f; + matrix_data[28] = -1.0f; + matrix_data[29] = 3.375f; + matrix_data[30] = -3.375f; + matrix_data[31] = 0.0f; + matrix_data[32] = 0.0f; + matrix_data[33] = 0.0625f; + matrix_data[34] = 0.0625f; + matrix_data[35] = 1.0f; + matrix_data[36] = 1.0f; + matrix_data[37] = 5.0625f; + matrix_data[38] = 5.0625f; + matrix_data[39] = 0.0f; + matrix_data[40] = 0.0; + matrix_data[41] = 0.03125f; + matrix_data[42] = -0.03125f; + matrix_data[43] = 1.0f; + matrix_data[44] = -1.0f; + matrix_data[45] = 7.59375f; + matrix_data[46] = -7.59375f; + matrix_data[47] = 0.0f; + matrix_data[48] = 0.0f; + matrix_data[49] = 0.015625f; + matrix_data[50] = 0.015625f; + matrix_data[51] = 1.0f; + matrix_data[52] = 1.0f; + matrix_data[53] = 11.390625f; + matrix_data[54] = 11.390625f; + matrix_data[55] = 1.0f; +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MATRIX_TABLE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/mul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/mul_parameter.h new file mode 100644 index 00000000000..e35a5ba4cde --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/mul_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MUL_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MUL_PARAMETER_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct MulParameter { + OpParameter op_parameter_; + int thread_count_; + MulQuantArg mul_quant_arg_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_MUL_PARAMETER_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/offset_utils.h b/mindspore/lite/src/runtime/kernel/arm/opclib/offset_utils.h new file mode 100644 index 00000000000..3e9d3f560fd --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/offset_utils.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OFFSET_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OFFSET_UTILS_H_ + +#ifdef ENABLE_NEON +#include +#endif + +inline int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) { + return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3; +} + +inline int offset4d(const int *shape, const int *dims) { return offset(shape, dims[0], dims[1], dims[2], dims[3]); } +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OFFSET_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/op_base.h b/mindspore/lite/src/runtime/kernel/arm/opclib/op_base.h new file mode 100644 index 00000000000..0084b83fffb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/op_base.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OP_BASE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OP_BASE_H_ + +#include +#include "src/runtime/kernel/arm/opclib/quantization/quantize.h" + +#define C4NUM 4 +#define C8NUM 8 +#define BLOCK 4 +#define TILE_NUM 8 + +#define MSMIN(x, y) ((x) < (y) ? (x) : (y)) +#define MSMAX(x, y) ((x) > (y) ? (x) : (y)) + +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y)) +#define UP_ROUND_DIV(x, y) (x % y == 0 ? (x / y) : (x / y) + 1) +#define DOWN_DIV(x, y) (((x) - (y) + (1)) / (y)) + +#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right)) + +#define DIMENSION_4D 4 + +#define kInputIndex 0 +#define kWeightIndex 1 +#define kBiasIndex 2 +#define kOutputIndex 0 +#define kNHWC_N 0 +#define kNHWC_H 1 +#define kNHWC_W 2 +#define kNHWC_C 3 +#define kInputSize1 2 +#define kInputSize2 3 + +struct OpParameter { + char name_[100]; + int type_; + int thread_num_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OP_BASE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.cc new file mode 100644 index 00000000000..c43f23e6cbb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/opclib_utils.h" +#ifdef __ANDROID__ +#include +#endif + +#if defined(__ANDROID__) +uint32_t getHwCap(int hwcap_type) { + uint32_t ret = getauxval(hwcap_type); + return ret; +} +#endif + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.h b/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.h new file mode 100644 index 00000000000..4f8035a107f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/opclib_utils.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPCLIB_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPCLIB_UTILS_H_ + +#include + +#if defined(__arm__) || defined(__aarch64__) +uint32_t getHwCap(int hwcap_type); +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPCLIB_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c b/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c new file mode 100644 index 00000000000..b6dec4f2e8b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/opt_op_handler.c @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +// todo +extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + size_t ksize, size_t ic4, size_t output_channel, size_t offset, + const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, + size_t out_multiplier, size_t shift_before, size_t shift_after); + +void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, + size_t ksize, size_t ic4, size_t output_channel, size_t offset, + const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, + size_t out_multiplier, size_t shift_before, size_t shift_after) { + return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min, + act_max, out_zp, out_multiplier, shift_before, shift_after); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h b/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h new file mode 100644 index 00000000000..756056d6a60 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/optimized_kernel.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPTIMIZED_KERNEL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPTIMIZED_KERNEL_H_ + +#include +#ifdef __ANDROID__ +#include +#include "src/runtime/kernel/arm/opclib/opclib_utils.h" +#endif + +#define OPTIMIZE_SHARED_LIBRARY_PATH "liboptimize.so" + +class OptimizeModule { + public: + OptimizeModule() { + bool support_optimize_ops = false; + +#ifdef __ANDROID__ + int hwcap_type = 16; + uint32_t hwcap = getHwCap(hwcap_type); +#if defined(__aarch64__) + if (hwcap & HWCAP_ASIMDDP) { + printf("Hw cap support SMID Dot Product, hwcap: 0x%x \n", hwcap); + support_optimize_ops = true; + } else { + printf("Hw cap NOT support SIMD Dot Product, hwcap: 0x%x\n", hwcap); + } +#endif +#endif + if (!support_optimize_ops) { + return; + } + optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY); + if (optimized_op_handler_ == nullptr) { + printf("Open optimize shared library failed.\n"); + } + } + + ~OptimizeModule() = default; + + static OptimizeModule *GetInstance() { + static OptimizeModule opt_module; + return &opt_module; + } + void *optimized_op_handler_ = nullptr; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OPTIMIZED_KERNEL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc new file mode 100644 index 00000000000..453d9f8afe4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.cc @@ -0,0 +1,1093 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/pack.h" +#include +#include + +#ifdef ENABLE_FP16 +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int channel_block = UP_DIV(in_channel, 4); + int kernel_plane = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * channel_block * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * channel_block * C4NUM; + int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM; + for (int m = 0; m < channel_block; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * 16 * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + } // tile num loop +} + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight) { + // original weight format : ohwi + int tile_num = 8; + int inchannel_block = 4; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int kernel_block = UP_DIV(out_channel, tile_num); + int channel_block = UP_DIV(in_channel, inchannel_block); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = kernel_block * channel_block * tile_num * inchannel_block * kernel_plane; + + int unit_size = tile_num * inchannel_block; + int block_size = pack_weight_size / kernel_block; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * channel_block; + for (int i = 0; i < channel_block; i++) { + int channel_block_stride = kernel_plane_stride + i * inchannel_block; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * inchannel_block; + int real_ic_num = ic_remainder < inchannel_block ? ic_remainder : inchannel_block; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h * tile_num; + for (int j = 0; j < kernel_block; j++) { + int kernel_block_stride = block_stride + j * tile_num * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * tile_num; + int real_oc_num = oc_remainder < tile_num ? oc_remainder : tile_num; + for (int k = 0; k < real_oc_num; k++) { + float16_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + float16_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k; + *packed_data_ptr = *origin_data_ptr; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = UP_DIV(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; + for (int i = 0; i < input_channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic4 = UP_DIV(input_channel, C4NUM); + int output_channel = conv_param->output_channel_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C4NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic4 * kernel_plane * C4NUM; + for (int i = 0; i < input_channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c4_block_num * kernel_plane * C4NUM + c4_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (origin_weight_data + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + nhwc4_batch_offset + i * ic4 * C4NUM, (float16_t *)src + batch_offset + i * channel, + channel * sizeof(float16_t)); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} +#endif + +void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int oc8 = UP_DIV(out_channel, C8NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int pack_weight_size = oc8 * ic4 * C8NUM * C4NUM * kernel_plane; + + int unit_size = C8NUM * C4NUM; + int block_size = pack_weight_size / oc8; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * ic4; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h * C8NUM; + for (int j = 0; j < oc8; j++) { + int kernel_block_stride = block_stride + j * C8NUM * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * C8NUM; + int real_oc_num = oc_remainder < C8NUM ? oc_remainder : C8NUM; + for (int k = 0; k < real_oc_num; k++) { + float *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + float *packed_data_ptr = packed_weight + packed_kernel_block_size + k; + *packed_data_ptr = *origin_data_ptr; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int plane_c4 = UP_DIV(kernel_plane, C4NUM); + int pack_weight_size = oc4 * C4NUM * ic4 * C4NUM * plane_c4 * C4NUM; + int block_size = pack_weight_size / oc4; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * C4NUM; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * C4NUM * C4NUM * C4NUM; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h; + for (int j = 0; j < oc4; j++) { + int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * C4NUM; + int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM; + for (int k = 0; k < real_oc_num; k++) { + int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM * C4NUM; + *packed_data_ptr = origin_data_ptr[0]; + // value of weight must between [-127, 127] + if (packed_data_ptr[0] == -128) { + packed_data_ptr[0] = -127; + } + weight_sum[j * C4NUM + k] += (int32_t)packed_data_ptr[0]; + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { + // original weight format : ohwi + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int oc4 = UP_DIV(out_channel, C4NUM); + int ic4 = UP_DIV(in_channel, C4NUM); + int kernel_plane = kernel_h * kernel_w; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; + int unit_size = C4NUM * C4NUM; + int block_size = pack_weight_size / oc4; + + for (int m = 0; m < kernel_plane; m++) { + int kernel_plane_stride = m * in_channel; + int packed_kernel_plane_stride = m * unit_size * ic4; + for (int i = 0; i < ic4; i++) { + int channel_block_stride = kernel_plane_stride + i * C4NUM; + int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size; + int ic_remainder = in_channel - i * C4NUM; + int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM; + for (int h = 0; h < real_ic_num; h++) { + int block_stride = channel_block_stride + h; + int packed_block_stride = packed_channel_block_size + h; + for (int j = 0; j < oc4; j++) { + int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel; + int packed_kernel_block_size = packed_block_stride + j * block_size; + int oc_remainder = out_channel - j * C4NUM; + int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM; + for (int k = 0; k < real_oc_num; k++) { + int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel; + int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM; + *packed_data_ptr = origin_data_ptr[0]; + if (packed_data_ptr[0] == -128) { + packed_data_ptr[0] = -127; + } + weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - filter_zp); + } + } // kernel block loop + } // inchannel block loop + } // channel block loop + } // kernel plane loop +} + +void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param) { + for (int c = 0; c < UP_DIV(conv_param->input_channel_, C4NUM); c++) { + const float *src_c_ptr = src + c * conv_param->input_h_ * conv_param->input_w_ * C4NUM; + float *dst_c_ptr = dst + c * conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const float *src_h_ptr = src_c_ptr + src_h * conv_param->input_w_ * C4NUM; + float *dst_h_ptr = dst_c_ptr + dst_h * conv_param->output_w_ * C4NUM; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * C4NUM, src_h_ptr + src_w * C4NUM, C4NUM * sizeof(float)); + } + } + } + return; +} + +void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param) { + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < conv_param->output_channel_; oc++) { + int oc4mod = oc % 4; + int oc4div = oc / 4; + int dst_index = oc4div * c4 * C4NUM + ic * C4NUM + oc4mod; + int src_index = oc * conv_param->input_channel_ + ic; + packed_weight[dst_index] = weight_data[src_index]; + } + } + return; +} + +void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * C8NUM * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32(packed_input + channel_block_offset, vld1q_f32(input_data + channel_block_stride)); +#else + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; +#endif + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + } // tile num loop +} + +void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param) { + // input format : nhwc + int tile_num = 4; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_w = conv_param->output_w_; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_cal_num_offset = i * C4NUM * C4NUM; + int32_t input_accumulator = 0; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int plane_c4_block = (j * kernel_w + n) / C4NUM; + int plane_c4_res = (j * kernel_w + n) % C4NUM; + int input_plane_offset = + plane_c4_block * tile_num * C4NUM * C4NUM * ic4 + plane_c4_res * C4NUM + input_cal_num_offset; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * tile_num * C4NUM * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + input_accumulator += (packed_input + channel_block_offset)[0]; + input_accumulator += (packed_input + channel_block_offset)[1]; + input_accumulator += (packed_input + channel_block_offset)[2]; + input_accumulator += (packed_input + channel_block_offset)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + input_sum[i] = input_accumulator * filter_zp; + } // tile num loop +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param) { + // input format : nhwc + int tile_num = 24; + int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic4 = UP_DIV(in_channel, C4NUM); + int out_w = conv_param->output_w_; + int block_size = kernel_h * kernel_w; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + for (int j = 0; j < kernel_h; j++) { + int input_y = input_h + j * dilation_h; + if (input_y < 0 || input_y >= in_h) { + continue; + } + int input_y_stride = input_y * in_w * ic4 * C4NUM; + for (int n = 0; n < kernel_w; n++) { + int input_x = input_w + n * dilation_w; + if (input_x < 0 || input_x >= in_w) { + continue; + } + int input_x_stride = input_y_stride + input_x * ic4 * C4NUM; + int input_plane_offset = (j * kernel_w + n) * tile_num * C4NUM * ic4 + i * C4NUM; + for (int m = 0; m < ic4; m++) { + int channel_block_stride = input_x_stride + m * C4NUM; + int channel_block_offset = input_plane_offset + m * tile_num * C4NUM; + (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0]; + (packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1]; + (packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2]; + (packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3]; + } // channel_block loop + } // kernel_w loop + } // kernel_h loop + int32_t input_accumulator = 0; + for (int j = 0; j < block_size; j++) { + int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM; + for (int c = 0; c < ic4; c++) { + int ic4_offset = block_offset + c * tile_num * C4NUM; + input_accumulator += (packed_input + ic4_offset)[0]; + input_accumulator += (packed_input + ic4_offset)[1]; + input_accumulator += (packed_input + ic4_offset)[2]; + input_accumulator += (packed_input + ic4_offset)[3]; + } + } + input_sum[i] = input_accumulator * filter_zp; + } // tile num loop +} + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8 = UP_DIV(in_channel, C8NUM); + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_h * in_w; + int dst_batch_offset = b * ic8 * C8NUM * in_h * in_w; + for (int c = 0; c < in_channel; c++) { + int ic8_block = c / C8NUM; + int ic8_res = c % C8NUM; + int src_c_offset = src_batch_offset + c; + int dst_c_offset = dst_batch_offset + ic8_block * C8NUM * in_h * in_w + ic8_res; + for (int k = 0; k < in_w * in_h; k++) { + int src_plane_offset = src_c_offset + k * in_channel; + int dst_plane_offset = dst_c_offset + k * C8NUM; + (packed_input + dst_plane_offset)[0] = (int16_t)(input_data + src_plane_offset)[0]; + } + } + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = UP_DIV(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + int filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; + for (int i = 0; i < input_channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp); + } + } + } +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + nhwc4_batch_offset + i * c4 * C4NUM, (float *)src + batch_offset + i * channel, + channel * sizeof(float)); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM, + channel * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; + ((float *)dst)[dst_plane_offset] = ((float *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#ifdef ENABLE_NEON + vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc4_batch_offset + i * c4 * C4NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy(reinterpret_cast(dst) + batch_offset + i * channel, + reinterpret_cast(src) + nhwc4_batch_offset + i * c4 * C4NUM, channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy(dst, src, ori_input_size); + } +} + +void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; + ((uint8_t *)dst)[dst_plane_offset] = ((uint8_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index]; + } + } + } + return; +} + +void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c8 * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C8NUM; + for (int i = 0; i < channel; i++) { + int c8_block_num = i / C8NUM; + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c8_block_num * plane * C8NUM + c8_block_rem; + ((int8_t *)dst + dst_ic_offset)[0] = ((int8_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((float *)dst)[nchw_index] = ((float *)src)[nhwc_index]; + } + } + } + return; +} + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((float *)dst)[nhwc_index] = ((float *)src)[nchw_index]; + } + } + } + return; +} + +void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) { + size_t copy_size = row * C4NUM * sizeof(float); + for (int c = 0; c < col; c++) { + memcpy(dst + c * dst_stride, src + c * src_stride, copy_size); + } +} + +void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) { + int row4mod = row % 4; + int row4div = row / 4; + + for (int i = 0; i < row4div; i++) { + MatrixPackUnit(src + i * 4 * 4, dst + i * 4 * ic4 * 4, 4, ic4, stride, 16); + } + + if (row4mod > 0) { + MatrixPackUnit(src + row4div * 4 * 4, dst + row4div * 4 * ic4 * 4, row4mod, ic4, stride, row4mod * 4); + } + return; +} + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + auto input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + auto src_b = src + b * unit * conv_param->input_channel_; + auto dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + auto src_k = src_b + k * conv_param->input_channel_; + auto dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)((int32_t)(src_k[c]) - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, const ConvParameter *conv_param) { + auto weight_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_; + int unit = conv_param->kernel_h_ * conv_param->kernel_w_; + for (int c = 0; c < conv_param->output_channel_; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + auto src_c = origin_weight + c * unit; + auto dst_c = packed_weight_ + c4_block_num * unit * C4NUM; + for (int k = 0; k < unit; k++) { + auto src_kernel = src_c + k; + auto dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)((int32_t)(src_kernel[0]) - weight_zp); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h new file mode 100644 index 00000000000..5bc02976e40 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pack.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +#ifdef ENABLE_FP16 +void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num, + int block_index); + +void PackWeightFp16(float16_t *weight_data, ConvParameter *conv_param, float16_t *packed_weight); + +void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); +#endif +void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index); + +void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param); + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, + int32_t *input_sum, ConvParameter *conv_param); + +void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param); + +void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); + +void MatrixPack(const float *src, float *dst, int row, int ic4, int stride); + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); + +void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight); + +void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); + +void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum); + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param); + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); + +void PackDepthwiseInt8Weight(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); + +inline void UnpackHwcToChwFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { + int cur = 0; + for (int i = 0; i < channel; i++) { + auto plane = i / BLOCK; + auto offset = i % BLOCK; + auto src_plane = plane * h * w * BLOCK + src_ptr; + for (int j = 0; j < h * w; j++) { + dst_ptr[cur++] = src_plane[j * BLOCK + offset]; + } + } +} + +inline void C8UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { + int cur = 0; + for (int j = 0; j < h * w; j++) { + for (int i = 0; i < channel; i++) { + auto plane = i / 8; + auto offset = i % 8; + auto src_plane = plane * h * w * 8 + src_ptr; + dst_ptr[cur++] = src_plane[j * 8 + offset]; + } + } +} + +inline void C4UnpackToHwcFp32(float *src_ptr, float *dst_ptr, int channel, int h, int w) { + int cur = 0; + for (int j = 0; j < h * w; j++) { + for (int i = 0; i < channel; i++) { + auto plane = i / 4; + auto offset = i % 4; + auto src_plane = plane * h * w * 4 + src_ptr; + dst_ptr[cur++] = src_plane[j * 4 + offset]; + } + } +} + +inline void C4UnpackToHwcInt8(int8_t *src_ptr, int8_t *dst_ptr, int channel, int h, int w) { + int cur = 0; + for (int j = 0; j < h * w; j++) { + for (int i = 0; i < channel; i++) { + auto plane = i / 4; + auto offset = i % 4; + auto src_plane = plane * h * w * 4 + src_ptr; + dst_ptr[cur++] = src_plane[j * 4 + offset]; + } + } +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pad.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/pad.cc new file mode 100644 index 00000000000..6ad445db495 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pad.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/pad.h" +#include +#include "src/runtime/kernel/arm/opclib/offset_utils.h" + +void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + const int *paddings, const int tid, const int thread_num) { + int in[4], out[4]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + output_data[offset4d(output_shape, out)] = input_data[offset4d(input_shape, in)]; + } + } + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/pad.h b/mindspore/lite/src/runtime/kernel/arm/opclib/pad.h new file mode 100644 index 00000000000..7b3afe9b337 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/pad.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct PadParameter { + OpParameter op_parameter_; + int paddings[8]; + size_t ori_size_; +}; + +void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + const int *paddings, const int tid, const int thread_num); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/power.h b/mindspore/lite/src/runtime/kernel/arm/opclib/power.h new file mode 100644 index 00000000000..6605ba3c0ae --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/power.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_POWER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_POWER_H_ +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct PowerParameter { + OpParameter op_parameter_; + float power_; + float scale_; + float shift_; +}; + +inline void Power(const float *input_data, float *output_data, int len, float power, float scale, float shift) { + for (int i = 0; i < len; ++i) { + output_data[i] = pow((scale * input_data[i] + shift), power); + } +} + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_POWER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.cc new file mode 100644 index 00000000000..3dff4230f02 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/prelu.h" + +void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id) { + for (int i = task_id; i < prelu_param_->input_num_; i += prelu_param_->op_parameter_.thread_num_) { + if (input[i] <= 0) { + output[i] = input[i] * prelu_param_->negtive_slope_[0]; + } else { + output[i] = input[i]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.h b/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.h new file mode 100644 index 00000000000..d392ca1d5f4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/prelu.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PRELU_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PRELU_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct PReluParameter { + OpParameter op_parameter_; + float *negtive_slope_; + int input_num_; + int thread_num_; +}; + +void PRelu(float *input, float *output, PReluParameter *prelu_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PRELU_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h new file mode 100644 index 00000000000..3249d2e47c4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/fixed_point.h @@ -0,0 +1,687 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_ + +#include +#include +#include +#include +#include +#ifdef ENABLE_NEON +#include +#endif + +// Part 1: Low-level integer-arithmetic primitives. +// The implementations here are generic implementations valid for +// scalar types (e.g. std::int32_t). Architecture-specific SIMD types +// (e.g. NEON int32x4_t) may be supported by providing +// specializations for them in separate files. +// +// The purpose of these primitives is two-fold: +// - They will be used to implement higher-level fixed-point +// abstractions, namely the FixedPoint class and its arithmetic +// operators. +// - They will be directly used to implement some more involved +// fixed-point computations, e.g. the fixed-point implementation +// of math functions such as tanh. + +// Some compile-time traits around raw types to handle SIMD aspects: +// number of lanes, underlying scalar type. +template +struct FixedPointRawTypeTraits {}; + +template <> +struct FixedPointRawTypeTraits { + typedef std::int32_t ScalarRawType; + static constexpr int kLanes = 1; +}; + +template <> +struct FixedPointRawTypeTraits { + typedef std::int16_t ScalarRawType; + static constexpr int kLanes = 1; +}; + +// Returns a SIMD value duplicating a scalar value across all lanes. +template +tRawType Dup(typename FixedPointRawTypeTraits::ScalarRawType x) { + return x; +} + +// Plain bit-wise AND +template +tIntegerType BitAnd(tIntegerType a, tIntegerType b) { + return a & b; +} + +// Plain bit-wise OR +template +tIntegerType BitOr(tIntegerType a, tIntegerType b) { + return a | b; +} + +// Plain bit-wise XOR +template +tIntegerType BitXor(tIntegerType a, tIntegerType b) { + return a ^ b; +} + +// Plain bit-wise NOT +template +tIntegerType BitNot(tIntegerType a) { + return ~a; +} + +// Integer addition. Not saturating. Overflow is undefined behavior. +template +tIntegerType Add(tIntegerType a, tIntegerType b) { + return a + b; +} + +// Integer multiplication. Not saturating. Overflow is undefined behavior. +template +tIntegerType Mul(tIntegerType a, tIntegerType b) { + return a * b; +} + +// Integer subtraction. Not saturating. Overflow is undefined behavior. +template +tIntegerType Sub(tIntegerType a, tIntegerType b) { + return a - b; +} + +// Integer unary negative. Not saturating. Overflow is undefined behavior. +template +tIntegerType Neg(tIntegerType a) { + return -a; +} + +// Integer arithmetic left-shift, equivalent to multiplying with a power of two. +// Negative values are OK. In case of overflow, no Undefined +// Behavior, but the results are implementation-defined (in practice, +// they currently are saturated, but we make no commitment to that). The idea +// is that the caller will want to implement the overflowing cases with +// saturation with compare-and-mask, so we don't care about the results +// in the overflow case, we just want to avoid undefined behavior. +// +// tIntegerType may be int32 or any narrower signed type. +template +tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) { + const std::int64_t wide_a = (std::int64_t)(a); + const std::int64_t wide_shifted = wide_a * (1 << offset); + const auto min = std::numeric_limits::min(); + const auto max = std::numeric_limits::max(); + return wide_shifted < min ? min : wide_shifted > max ? max : (tIntegerType)(wide_shifted); +} + +// Integer arithmetic right-shift. Not rounding. +// Relying on implementation-defined, but in-practice-consistent, +// C++ compiler behavior. +template +tIntegerType ShiftRight(tIntegerType a, int offset) { + return a >> offset; +} + +// Each bit of the result is set to the corresponding bit of either then_val or +// else_val depending on whether the corresponding bit of if_mask is set. +// Equivalent to the VBSL instruction in ARM NEON. +template +tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, tIntegerType else_val) { + return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); +} + +// For each input scalar, the corresponding bits of the result are set if the +// input scalar is non-zero. +template +tIntegerType MaskIfNonZero(tIntegerType a) { + static constexpr tIntegerType zero = 0; + return a ? BitNot(zero) : zero; +} + +// For each input scalar, the corresponding bits of the result are set if the +// input scalar is zero. +template +tIntegerType MaskIfZero(tIntegerType a) { + return MaskIfNonZero(!a); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars are equal. +template +tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a == b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars are not equal. +template +tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a != b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a > b. +template +tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a > b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a >= b. +template +tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a >= b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a < b. +template +tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a < b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a <= b. +template +tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero(a <= b); +} + +// Returns true if all of the input scalars are nonzero. +// This function may currently assume that each of the input scalars has either +// all or none of its bits set. Otherwise, its behavior is currently undefined. +template +bool All(tIntegerType a) { + return a; +} + +// Returns true if any of the input scalars are nonzero. +// This function may currently assume that each of the input scalars has either +// all or none of its bits set. Otherwise, its behavior is currently undefined. +template +bool Any(tIntegerType a) { + return a; +} + +// Returns (a+b)/2, rounded to the nearest integer. +// Equivalent to VRHADD in the ARM NEON instruction set. +template +IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + (void)b; + return a; +} + +template <> +inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + std::int64_t sign = sum >= 0 ? 1 : -1; + return (std::int32_t)((sum + sign) / 2); +} + +template <> +inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t sum = a32 + b32; + std::int32_t sign = sum >= 0 ? 1 : -1; + return (std::int16_t)((sum + sign) / 2); +} + +template +IntegerType SaturatingAdd(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + (void)b; + return a; +} + +// So far this is only needed for int16. +template <> +inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t sum = a32 + b32; + return (std::int16_t)(std::min((std::int32_t)(32767), std::max((std::int32_t)(-32768), sum))); +} + +template <> +inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { + std::int16_t a16 = a; + std::int16_t b16 = b; + std::int16_t sum = a16 + b16; + return (std::int8_t)(std::min((int16_t)(std::numeric_limits::max()), + std::max((int16_t)(std::numeric_limits::min()), sum))); +} + +// Returns a+b, saturating if the integers are 16bit or narrower, +// otherwise just a plain addition. +template +struct AddSaturatingIf16BitImpl { + static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } +}; +template +struct AddSaturatingIf16BitImpl { + static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); } +}; +template +IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { + using ScalarType = typename FixedPointRawTypeTraits::ScalarRawType; + return AddSaturatingIf16BitImpl::Run(a, b); +} + +// Returns the integer that represents the product of two fixed-point +// numbers, interpreting all integers as fixed-point values in the +// interval [-1, 1), rounding to the nearest value, and saturating +// -1 * -1 to the maximum value (since 1 is not in the half-open +// interval [-1, 1)). +// +// [The explanation below specializes to std::int32_t for example purpose.] +// +// The mapping between IntegerType and the interval [-1, 1) is unique and +// implied by IntegerType, which is assumed to be signed. For example, +// for IntegerType==std::int32_t, the mapping is +// real_value = integer_value / 2^31. +// So in this case, and leaving aside rounding and saturating, this +// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to +// (a * b) / 2^31. +// +// The 'doubling' part in the name of this function comes from the fact that +// this operation is very close to a "multiply-high" operation, keeping only +// the top half bits, except that that would be effectively computing +// (a * b) / 2^32, +// so here we are computing 2x that, since +// 1/2^31 = 2 * 1/2^32. +// The idea is to use all of the available 32 bits in the destination int32 +// value. +// +// [End of the explanation specializing to int32.] +// +// This is equivalent to the VQRDMULH instruction in ARM NEON. +template +IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + (void)b; + return a; +} + +// This function implements the same computation as the ARMv7 NEON VQRDMULH +// instruction. +template <> +inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) { + bool overflow = a == b && a == std::numeric_limits::min(); + std::int64_t a_64(a); + std::int64_t b_64(b); + std::int64_t ab_64 = a_64 * b_64; + std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); + std::int32_t ab_x2_high32 = (std::int32_t)((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits::max() : ab_x2_high32; +} + +template <> +inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) { + bool overflow = a == b && a == std::numeric_limits::min(); + std::int32_t a_32(a); + std::int32_t b_32(b); + std::int32_t ab_32 = a_32 * b_32; + std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); + std::int16_t ab_x2_high16 = (std::int16_t)((ab_32 + nudge) / (1 << 15)); + return overflow ? std::numeric_limits::max() : ab_x2_high16; +} + +// Correctly-rounded-to-nearest division by a power-of-two. +// Also known as a rounding arithmetic right shift. +template +inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { + assert(exponent >= 0); + assert(exponent <= 31); + const IntegerType mask = Dup((1ll << exponent) - 1); + const IntegerType zero = Dup(0); + const IntegerType one = Dup(1); + const IntegerType remainder = BitAnd(x, mask); + const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); + return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one)); +} + +inline int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +// Returns the product of a run-time integer value by a compile-time power +// of two, with either a positive exponent (equivalent to an arithmetic +// left shift, saturating) or a negative exponent (equivalent to an arithmetic +// right shift, rounding to nearest). +template 0 ? 1 : Exponent < 0 ? -1 : 0)> +struct ImplSaturatingRoundingMultiplyByPOT {}; + +template +struct ImplSaturatingRoundingMultiplyByPOT { + static IntegerType eval(IntegerType x) { return x; } +}; + +template +struct ImplSaturatingRoundingMultiplyByPOT { + static IntegerType eval(IntegerType x) { + using ScalarIntegerType = typename FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = Dup(std::numeric_limits::min()); + const IntegerType max = Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); + const IntegerType positive_mask = MaskIfGreaterThan(x, Dup(threshold)); + const IntegerType negative_mask = MaskIfLessThan(x, Dup(-threshold)); + + IntegerType result = ShiftLeft(x, Exponent); + result = SelectUsingMask(positive_mask, max, result); + result = SelectUsingMask(negative_mask, min, result); + return result; + } +}; + +template +struct ImplSaturatingRoundingMultiplyByPOT { + static IntegerType eval(IntegerType x) { return RoundingDivideByPOT(x, -Exponent); } +}; + +template +IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { + return ImplSaturatingRoundingMultiplyByPOT::eval(x); +} + +// Part 2: the FixedPoint class. + +// A FixedPoint object represents a fixed-point value stored in the underlying +// integer type tRawType, if tRawType is a plain scalar integer type. +// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which +// case a FixedPoint object represents a corresponding SIMD vector of fixed +// point values. +// +// tIntegerBits describes the range of the fixed-point format: if +// tIntegerBits == m then the range of representable values is the half-open +// interval [-2^m; 2^m) where the open boundary on the right side means that +// 2^m is not representable (how close the maximum representable value is to +// it, depends on bit-depth of tRawType). +// +// In "Q format notation", +// https://en.wikipedia.org/wiki/Q_(number_format) +// we are describing the format +// Qm.n +// where +// m = tIntegerBits +// and +// n = NumberOfBits(tRawType) - (m + 1) +// Note that the (m + 1) in the above line is because we adopt the convention +// that we count the integer bits exclusively of the sign bit; so (m + 1) is +// the total number of integer bits inclusive of the sign bit. +// +// Accordingly, the number of integral representable values in our range +// [-2^m ; 2^m) +// is equal to 2^(m+1). +template +class FixedPoint { + public: + typedef tRawType RawType; + + typedef FixedPointRawTypeTraits RawTypeTraits; + typedef typename RawTypeTraits::ScalarRawType ScalarRawType; + + static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); + static constexpr int kIntegerBits = tIntegerBits; + static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; + static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); + + typedef FixedPoint ScalarFixedPointType; + + static const ScalarRawType ScalarRawMin() { return std::numeric_limits::min(); } + + static const ScalarRawType ScalarRawMax() { return std::numeric_limits::max(); } + + static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); } + + static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); } + + static FixedPoint FromRaw(RawType x) { + FixedPoint retval; + retval.raw() = x; + return retval; + } + + static FixedPoint FromScalarRaw(ScalarRawType x) { + FixedPoint retval; + retval.raw() = Dup(x); + return retval; + } + + static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); } + + template + static FixedPoint ConstantPOT() { + static constexpr int kOffset = kFractionalBits + Exponent; + static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format"); + return FromScalarRaw(ScalarRawType(1) << kOffset); + } + + static FixedPoint Zero() { return FromScalarRaw(0); } + + static FixedPoint One() { + return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() + : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); + } + + static FixedPoint FromDouble(double x) { + const double min_bound = (double)(ScalarRawMin()); + const double max_bound = (double)(ScalarRawMax()); + return FromScalarRaw( + (ScalarRawType)(std::min(std::max(round(x * (double)(1ll << kFractionalBits)), min_bound), max_bound))); + } + + RawType raw() const { return i_; } + RawType &raw() { return i_; } + + private: + RawType i_; +}; + +// Part 3: implementation of arithmetic operators for the +// FixedPoint class, and a few related functions. + +// A FixedPoint multiplication is just a +// SaturatingRoundingDoublingHighMul operation on the underlying +// raw integer values. The IntegerBits simply add up, as is obvious +// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). +template +FixedPoint operator*(FixedPoint a, + FixedPoint b) { + FixedPoint c; + c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); + return c; +} + +// Tweaking IntegerBits gives exact multiplication by a power of two. +template +FixedPoint ExactMulByPot(FixedPoint a) { + FixedPoint c; + c.raw() = a.raw(); + return c; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +FixedPoint SaturatingRoundingMultiplyByPOT(FixedPoint a) { + return FixedPoint::FromRaw(SaturatingRoundingMultiplyByPOT(a.raw())); +} + +// Generic arithmetic operators. + +#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ + template \ + FixedPoint FuncName(FixedPoint a) { \ + return FixedPoint::FromRaw(ImplFuncName(a.raw())); \ + } + +#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ + template \ + FixedPoint FuncName(FixedPoint a, \ + FixedPoint b) { \ + return FixedPoint::FromRaw(ImplFuncName(a.raw(), b.raw())); \ + } + +MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) +MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) +MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) +MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) +MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) +MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) +MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) +MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) + +#undef MAKE_FIXEDPOINT_UNARY_FUNC +#undef MAKE_FIXEDPOINT_BINARY_FUNC + +#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ + template \ + tRawType FuncName(FixedPoint a) { \ + return FuncName(a.raw()); \ + } + +#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ + template \ + tRawType FuncName(FixedPoint a, FixedPoint b) { \ + return FuncName(a.raw(), b.raw()); \ + } + +MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) +MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) + +#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW +#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW + +template +FixedPoint SelectUsingMask(tRawType if_mask, FixedPoint then_val, + FixedPoint else_val) { + return FixedPoint::FromRaw(SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); +} + +template +bool operator==(FixedPoint a, FixedPoint b) { + return All(MaskIfEqual(a.raw(), b.raw())); +} + +template +bool operator!=(FixedPoint a, FixedPoint b) { + return !(a == b); +} + +template +FixedPoint SaturatingAdd(FixedPoint a, + FixedPoint b) { + return FixedPoint::FromRaw(SaturatingAdd(a.raw(), b.raw())); +} + +template +FixedPoint AddSaturatingIf16Bit(FixedPoint a, + FixedPoint b) { + return FixedPoint::FromRaw(AddSaturatingIf16Bit(a.raw(), b.raw())); +} + +// Conversion to floating-point. +template +double ToDouble(FixedPoint x) { + static_assert(FixedPointRawTypeTraits::kLanes == 1, "not applicable to SIMD types"); + typedef FixedPoint F; + return x.raw() / (double)(1ll << F::kFractionalBits); +} + +// Rescale changes the number of IntegerBits and updates the underlying +// raw integer value accordingly. +template +FixedPoint Rescale(FixedPoint x) { + static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; + FixedPoint result; + result.raw() = SaturatingRoundingMultiplyByPOT(x.raw()); + return result; +} + +// CheckedFixedPointConstant allows to specify fixed-point constants +// initialized as real numbers, in a way that does not compile floating-point +// arithmetic in production code, yet still checks agreement with the +// floating-point expressions when asserts are enabled. +// +// The raw integer value provided is always a int32, encoding a 32-bit +// fixed-point value, regardless of the actual Scalar type. This allows +// writing generic code that applies just as well to the 32-bit and 16-bit +// cases. In the 16-bit case, the raw integer value is internally +// rounding-shifted by 16 bits to the right. +template +inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(std::int32_t int32_value) { + typedef typename FixedPointType::ScalarRawType ScalarRawType; + static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); + return (ScalarRawType)(RoundingDivideByPOT(int32_value, 32 - ScalarTypeBits)); +} + +// Implementation of exponential function. + +// Returns -tanh(x) for x < 0. +template +FixedPoint neg_tanh_on_negative_values(FixedPoint a) { + return one_minus_x_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(ExactMulByPot<1>(a))); +} + +// Returns tanh(x) for any x. +template +FixedPoint tanh(FixedPoint a) { + typedef FixedPoint InputF; + typedef FixedPoint ResultF; + tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); + tRawType mask_if_zero = MaskIfZero(a); + InputF n = SelectUsingMask(mask_if_negative, a, -a); + ResultF t = neg_tanh_on_negative_values(n); + return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t)); +} + +// Implementation of logistic function. + +// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. +template +FixedPoint logistic_on_positive_values(FixedPoint a) { + return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); +} + +#ifdef ENABLE_NEON +inline int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { + const int32x4_t shift_vec = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +inline int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) { + return vqrdmulhq_s32(a, b); +} +#endif + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_FIXED_POINT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.cc new file mode 100644 index 00000000000..5add73fa8db --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/quantization/quantize.h" + +const uint64_t dSignMask = 1ull << 63; +const uint64_t dExponentMask = 0x7ffull << 52; +const uint64_t dFractionMask = (1ull << 52) - 1; +const int dExponentBias = 1022; +const int dMantissaBits = 52; +const int dInfiniteExponent = 0x7ff; +const double dNormalizer = 0x1p54; +const int dNormalizerBias = 54; +const int iMantissaBits = 31; + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift) { + if (quantized_multiplier == nullptr || shift == nullptr) { + return; + } + // we split a floating number into two parts: exponent and fraction + // since fraction is stored as int32, only 31 bits of mantissa is remained + union { + double d; + uint64_t ul; + } dul; + dul.d = double_multiplier; + if (!(dul.ul & (~dSignMask))) { + // multiplier is 0 + *quantized_multiplier = 0; + *shift = 0; + return; + } + int exponent = (int) ((dul.ul & dExponentMask) >> dMantissaBits); + if (exponent == dInfiniteExponent) { + // multiplier is inf or NaN + *shift = 0; + if (!(dul.ul & dFractionMask)) { + // inf + *quantized_multiplier = (dul.ul & dSignMask) ? INT_MIN : INT_MAX; + } else { + // NaN + *quantized_multiplier = 0; + } + return; + } + if (exponent == 0) { + // multiplier is a subnormal number + dul.d *= dNormalizer; + exponent = (int) ((dul.ul & dExponentMask) >> dMantissaBits); + *shift = exponent - dExponentBias - dNormalizerBias; + } else { + *shift = exponent - dExponentBias; + } + uint64_t fraction = dul.ul & dFractionMask; + fraction += (1ull << dMantissaBits); + uint64_t rounded = ((fraction >> (dMantissaBits - iMantissaBits)) + 1ull) >> 1; + // we get 31 rounded bits now + if (rounded == (1ull << iMantissaBits)) { + // rounding may cause a carry + rounded >>= 1; + ++*shift; + } + *quantized_multiplier = (dul.ul & dSignMask) ? (-(int32_t)(rounded)) : (int32_t)(rounded); +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h new file mode 100644 index 00000000000..9a6f474b727 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_QUANTIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_QUANTIZE_H_ + +#include +#include +#include +#include + +struct QuantArg { + double scale_; + int32_t zp_; +}; + +struct ConvQuantArg { + QuantArg **quant_args_; + double *real_multiplier_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; + int32_t *out_act_min_; + int32_t *out_act_max_; +}; + +struct ConcatQuantArg { + int *input_sizes_; + int output_size_; + int **input_shapes_; + int *output_shape_; + size_t input_num_; + size_t output_dim_; + QuantArg *in_quant_args_; + QuantArg out_quant_args_; +}; + +struct FcQuantArg { + QuantArg input; + QuantArg weight; + QuantArg output; + int32_t out_act_min; + int32_t out_act_max; + int32_t output_shift; + int32_t quant_multiplier; +}; + +struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +}; + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); + +inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, + int *right_shift) { + if (quantized_multiplier == nullptr || right_shift == nullptr) { + return; + } + int shift; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + *right_shift = -shift; +} + +inline void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, + int *right_shift) { + int shift; + QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +inline void CalculateActivationRangeQuantized(float fmax, float fmin, float scale, int zero_point, int *imax, + int *imin) { + int8_t qmin = (int8_t)CHAR_MIN; + int8_t qmax = (int8_t)CHAR_MAX; + int8_t qfmin = QuantizeToInt8(fmin, scale, zero_point); + int8_t qfmax = QuantizeToInt8(fmax, scale, zero_point); + *imin = qmin < qfmin ? qmin : qfmin; + *imax = qmax > qfmax ? qmax : qfmax; +} +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_QUANTIZATION_QUANTIZE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.cc new file mode 100644 index 00000000000..f0eb2a90b90 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.cc @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/reshape.h" +#include + +void Reshape(void *input_ptr, void *output_ptr, size_t data_size) { memcpy(output_ptr, input_ptr, data_size); } + + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.h b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.h new file mode 100644 index 00000000000..48d04fbbf5c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_H_ +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void Reshape(void *input_ptr, void *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/reshape_parameter.h b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape_parameter.h new file mode 100644 index 00000000000..5bb484daad8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/reshape_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_PARAMETER_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_PARAMETER_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ReshapeParameter { + OpParameter op_parameter_; + int thread_count_; +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESHAHPE_PARAMETER_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/resize.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/resize.cc new file mode 100644 index 00000000000..9265a9618b0 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/resize.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "src/runtime/kernel/arm/opclib/resize.h" +#include "src/runtime/kernel/arm/opclib/offset_utils.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + bool align_corners, int tid, int thread_num) { + if (input_data == nullptr || output_data == nullptr || input_shape == nullptr || output_shape == nullptr) { + return OPCLIB_NULL_PTR; + } + // nhwc (memory layout is nc4hw4) + int n = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int channel = input_shape[3]; + int c4 = UP_DIV(channel, C4NUM); + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + float height_scale = (float)(in_h) / new_height; + float width_scale = (float)(in_w) / new_width; + if (align_corners && new_height > 1) { + height_scale = (float)(in_h - 1) / (new_height - 1); + } + if (align_corners && new_width > 1) { + width_scale = (float)(in_w - 1) / (new_width - 1); + } + + int o[5]; // n c4 h w 4 + for (o[0] = 0; o[0] < n; o[0]++) { + for (o[1] = tid; o[1] < c4; o[1] += thread_num) { + for (o[2] = 0; o[2] < new_height; o[2]++) { + float actual_y = (float)(o[2]) * height_scale; + int y_left = (int)(floor(actual_y)); + int y_right = y_left + 1 < in_h ? (y_left + 1) : (in_h - 1); + float y_right_weight = actual_y - (float)(y_left); + float y_left_weight = 1.0 - y_right_weight; + for (o[3] = 0; o[3] < new_width; o[3]++) { + float actual_x = (float)(o[3]) * width_scale; + int x_left = (int)(floor(actual_x)); + int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1); + float x_right_weight = actual_x - (float)(x_left); + float x_left_weight = 1.0 - x_right_weight; + + auto input_base_offset = (((o[0] * c4 + o[1]) * in_h + y_left) * in_w + x_left) * C4NUM; + auto output_base_offset = (((o[0] * c4 + o[1]) * new_height + o[2]) * new_width + o[3]) * C4NUM; + int in_offset_1_0 = (y_right - y_left) * in_w * C4NUM; + int in_offset_0_1 = (x_right - x_left) * C4NUM; +#ifdef ENABLE_NEON + float32x4_t x_l_weight = vdupq_n_f32(x_left_weight); + float32x4_t x_r_weight = vdupq_n_f32(x_right_weight); + float32x4_t y_l_weight = vdupq_n_f32(y_left_weight); + float32x4_t y_r_weight = vdupq_n_f32(y_right_weight); + + float32x4_t input_yl_xl = vld1q_f32(input_data + input_base_offset); + float32x4_t input_yr_xl = vld1q_f32(input_data + input_base_offset + in_offset_1_0); + float32x4_t input_yl_xr = vld1q_f32(input_data + input_base_offset + in_offset_0_1); + float32x4_t input_yr_xr = vld1q_f32(input_data + input_base_offset + in_offset_0_1 + in_offset_1_0); + + float32x4_t interp_value = vdupq_n_f32(0.0); + float32x4_t interp_value_tmp = vmulq_f32(input_yl_xl, y_l_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_l_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yr_xl, y_r_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_l_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yl_xr, y_l_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_r_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + + interp_value_tmp = vmulq_f32(input_yr_xr, y_r_weight); + interp_value_tmp = vmulq_f32(interp_value_tmp, x_r_weight); + interp_value = vaddq_f32(interp_value, interp_value_tmp); + vst1q_f32(output_base_offset + output_data, interp_value); +#else + // 4 continuous data in a group; + for (o[4] = 0; o[4] < C4NUM; o[4]++) { + auto in_offset = input_base_offset + o[4]; + auto output_offset = output_base_offset + o[4]; + float interp_value = + input_data[in_offset] * y_left_weight * x_left_weight + + input_data[in_offset + in_offset_1_0] * y_right_weight * x_left_weight + + input_data[in_offset + in_offset_0_1] * y_left_weight * x_right_weight + + input_data[in_offset + in_offset_0_1 + in_offset_1_0] * y_right_weight * x_right_weight; + output_data[output_offset] = interp_value; + } +#endif + } + } + } + } + return OPCLIB_OK; +} + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + int tid, int thread_num) { + int batch, y, x, c; + c = input_shape[3]; + + float height_scale = (float)(input_shape[1]) / (float)(output_shape[1]); + float width_scale = (float)(input_shape[2]) / (float)(output_shape[2]); + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int actual_y = (int)(floor((float)(y) * height_scale)); + int input_y = actual_y < input_shape[1] ? actual_y : input_shape[1] - 1; + for (x = 0; x < output_shape[2]; x++) { + int actual_x = (int)(floor((float)(x) * width_scale)); + int input_x = actual_x < input_shape[2] ? actual_x : input_shape[2] - 1; + int in_offset = offset(input_shape, batch, input_y, input_x, 0); + int out_offset = offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float)); + } + } + } + + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/resize.h b/mindspore/lite/src/runtime/kernel/arm/opclib/resize.h new file mode 100644 index 00000000000..7a7033d12af --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/resize.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESIZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESIZE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "schema/ops_generated.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +using mindspore::schema::ResizeMethod; + +struct ResizeParameter { + OpParameter op_parameter_; + ResizeMethod method_; + int64_t new_height_; + int64_t new_width_; + bool align_corners_; + bool preserve_aspect_ratio_; +}; + +int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + bool align_corners, int tid, int thread_num); + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, + int tid, int thread_num); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_RESIZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.cc new file mode 100644 index 00000000000..051a17a88a5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/reverse_sequence.h" +#include +#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" + +void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para) { + (void)memcpy(output, input0, para->total_data_size_); + ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); + ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); + for (int i = 0; i < para->outer_count_; ++i) { + auto in = input0 + i * para->outer_stride_; + auto out = output + i * para->outer_stride_; + for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { + auto in_batch = in + batch * para->input_stride_[para->batch_axis_]; + auto out_batch = out + batch * para->output_stride_[para->batch_axis_]; + for (int n = 0; n < input1[batch]; ++n) { + auto in_seq = in_batch + (input1[batch] - 1 - n) * para->input_stride_[para->seq_axis_]; + auto out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; + for (int j = 0; j < para->inner_count_; ++j) { + (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); + } + } + } + } +} + + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.h b/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.h new file mode 100644 index 00000000000..4307a09a3cc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/reverse_sequence.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_SEQUENCE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_SEQUENCE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ReverseSequenceParameter { + OpParameter op_parameter_; + int ndim_; + int input_shape0_[5]; + int output_shape_[5]; + int input_stride_[5]; + int output_stride_[5]; + int seq_axis_; + int batch_axis_; + int outer_count_; + int outer_stride_; + int inner_count_; + int inner_stride_; + int copy_byte_size_; + int total_data_size_; +}; + +void ReverseSequence(float *input0, int *input1, float *output, ReverseSequenceParameter *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_REVERSE_SEQUENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/scale.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/scale.cc new file mode 100644 index 00000000000..df2700c09da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/scale.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/scale.h" +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int DoScale(float *in_data, float *out_data, float *scale, float *offset, int units_offset, int num_unit, + ScaleParameter *scale_param) { + if (in_data == nullptr || out_data == nullptr || scale == nullptr || offset == nullptr || scale_param == nullptr) { + return OPCLIB_ERR; + } + + int in_stride_j = units_offset * scale_param->in_stride_; + for (int j = units_offset; j < units_offset + num_unit; j++) { + int channel = j % scale_param->channel_; + for (int k = 0; k < scale_param->in_stride_; k++) { + out_data[in_stride_j + k] = in_data[in_stride_j + k] * scale[channel] + offset[channel]; + } + in_stride_j = in_stride_j + scale_param->in_stride_; + } + return OPCLIB_OK; +} + +int DoScale(float *in_data, float *out_data, float *scale, int units_offset, int num_unit, + ScaleParameter *scale_param) { + if (in_data == nullptr || out_data == nullptr || scale == nullptr || scale_param == nullptr) { + return OPCLIB_ERR; + } + + int in_stride_j = units_offset * scale_param->in_stride_; + for (int j = units_offset; j < units_offset + num_unit; j++) { + int channel = j % scale_param->channel_; + for (int k = 0; k < scale_param->in_stride_; k++) { + out_data[in_stride_j + k] = in_data[in_stride_j + k] * scale[channel]; + } + in_stride_j = in_stride_j + scale_param->in_stride_; + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/scale.h b/mindspore/lite/src/runtime/kernel/arm/opclib/scale.h new file mode 100644 index 00000000000..077fb5ae57f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/scale.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCALE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCALE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ScaleParameter { + OpParameter op_parameter_; + int out_count_; + int channel_; + int in_stride_; + int axis_; + int num_axis_; +}; + +int DoScale(float *in_data, float *out_data, float *scale, float *offset, int units_offset, int num_unit, + ScaleParameter *scale_param); +int DoScale(float *in_data, float *out_data, float *scale, int units_offset, int num_unit, ScaleParameter *scale_param); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCALE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.cc new file mode 100644 index 00000000000..26a205338a8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/scatter_nd.h" +#include +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int DoScatterND(float *output_ptr, float *update, int *output_unit_offsets, int unit_size, int num_units) { + if (output_ptr == nullptr || update == nullptr || output_unit_offsets == nullptr || unit_size <= 0 || num_units < 0) { + return OPCLIB_ERR; + } + for (int i = 0; i < num_units; i++) { + (void)memcpy(output_ptr + output_unit_offsets[i], update + unit_size * i, unit_size * sizeof(float)); + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.h b/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.h new file mode 100644 index 00000000000..c3f8a4d6ab1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/scatter_nd.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCATTER_ND_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCATTER_ND_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ScatterNDParameter { + OpParameter op_parameter_; +}; + +int DoScatterND(float *output_ptr, float *update, int *output_unit_offsets, int unit_size, int num_units); +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SCATTER_ND_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/shape.h b/mindspore/lite/src/runtime/kernel/arm/opclib/shape.h new file mode 100644 index 00000000000..ad63aa76173 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/shape.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARM_OPCLIB_SHAPE_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ARM_OPCLIB_SHAPE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct ShapeParameter { + OpParameter op_parameter_; +}; + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_OPCLIB_SHAPE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.cc new file mode 100644 index 00000000000..372c38e4bc6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/sparse_to_dense.h" + +void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, + SparseToDenseParameter *s2d_param_, int task_id) { + int m; + for (int i = task_id; i < output_shape_[0]; i += s2d_param_->op_parameter_.thread_num_) { + for (int j = 0; j < output_shape_[1]; j++) { + m = i * output_shape_[1] + j; + output[m] = dnum[0]; + } + } + + for (int j = 0; j < sp_num; j++) { + int temp = j * 2; + int temp1 = j * 2 + 1; + int tempout1 = input[temp] * output_shape_[1] + input[temp1]; + output[tempout1] = snum[j]; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.h new file mode 100644 index 00000000000..b04e0efa6de --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/sparse_to_dense.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPARSETODENSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPARSETODENSE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct SparseToDenseParameter { + OpParameter op_parameter_; + int thread_num_; + int count_ = 0; +}; + +void SparseToDense(int *input, int *output_shape_, float *snum, float *dnum, int sp_num, float *output, + SparseToDenseParameter *s2d_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPARSETODENCE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/split.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/split.cc new file mode 100644 index 00000000000..2a51702d419 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/split.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/split.h" +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *split_param) { + if (in_data == nullptr || out_data == nullptr) { + return OPCLIB_ERR; + } + int num_split = split_param->num_split_; + int *split_sizes = split_param->split_sizes_; + int *strides = split_param->strides_; + int split_dim = split_param->split_dim_; + int in_stride = strides[split_dim]; + + float *src; + int size_float = (int)(sizeof(float)); + int in_stride_bytes = in_stride * size_float; + + int split_which; + int split_times; + int stride_per_split = in_stride * input_shape[split_dim]; + + split_which = offset % num_split; + split_times = offset / num_split; + src = in_data + split_times * stride_per_split; + + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride; + } + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int split_size = split_sizes[split_which]; + float *dst = out_data[split_which] + split_times * in_stride * split_size; + (void)memcpy(dst, src, split_size * in_stride_bytes); + src += split_size * in_stride; + } + + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/split.h b/mindspore/lite/src/runtime/kernel/arm/opclib/split.h new file mode 100644 index 00000000000..3297a6afae2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/split.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPLIT_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPLIT_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct SplitParameter { + OpParameter op_parameter_; + int num_split_; + int split_sizes_[20] = {0}; + int strides_[8]; + int split_dim_; + int n_dims_; + int split_count_; +}; + +int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset, int num_unit, + SplitParameter *split_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SPLIT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.cc new file mode 100644 index 00000000000..68307088a72 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/squeeze.h" +#include + +int DoSqueeze(float *in_data, float *out_data, size_t data_size) { + if (in_data == nullptr || out_data == nullptr) { + return -1; + } + (void)memcpy(out_data, in_data, data_size); + return 0; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.h b/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.h new file mode 100644 index 00000000000..4c7fdddb1d4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/squeeze.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SQUEEZE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SQUEEZE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct SqueezeParameter { + OpParameter op_parameter_; + int axes_[8]; +}; + +int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_SQUEEZE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/strassen_matmul.h b/mindspore/lite/src/runtime/kernel/arm/opclib/strassen_matmul.h new file mode 100644 index 00000000000..91414c0e424 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/strassen_matmul.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_STRASSEN_MATMUL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_STRASSEN_MATMUL_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +/* hw*inc4 X inc4*oc4 */ +struct StrassenMatMulParameter { + OpParameter op_parameter; + int row_{}; /* h * w */ + int col_{}; /* oc4 / 4 */ + int deep_{}; /* inc4 / 4 */ + int a_stride_{}; /* h * w * 4 */ + int b_stride_{}; /* inc4 * 4 */ + int c_stride_{}; /* h * w * 4 */ +}; + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_STRASSEN_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/tile.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/tile.cc new file mode 100644 index 00000000000..e8129a77143 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/tile.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/tile.h" +#include + +void CopyData(float *input_data, float *output_data, size_t size, size_t multiple) { + float *out_data = output_data; + for (size_t i = 0; i < multiple; ++i) { + (void)memcpy(out_data, input_data, size * sizeof(float)); + out_data += size; + } +} + +int TileOneDimension(float *input_data, float *output_data, size_t dim, TileParameter *parameter) { + size_t src_dim_size = parameter->in_shape_[dim]; + if (dim == parameter->in_dim_ - 1) { + CopyData(input_data, output_data, src_dim_size, parameter->multiples_[dim]); + return 0; + } + for (size_t i = 0; i < src_dim_size; ++i) { + for (size_t j = 0; j < parameter->multiples_[dim]; ++j) { + size_t in_pos = parameter->in_strides_[dim] * i; + size_t out_pos = parameter->out_strides_[dim] * (i + j * src_dim_size); + TileOneDimension(input_data + in_pos, output_data + out_pos, dim + 1, parameter); + } + } + return 0; +} + +void Tile(float *input_data, float *output_data, TileParameter *parameter) { + TileOneDimension(input_data, output_data, 0, parameter); +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/tile.h b/mindspore/lite/src/runtime/kernel/arm/opclib/tile.h new file mode 100644 index 00000000000..5fd736642ce --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/tile.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TILE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TILE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct TileParameter { + OpParameter op_parameter_; + int in_dim_; + int in_shape_[5]; + int out_shape_[5]; + int multiples_[5]; + int in_strides_[5]; + int out_strides_[5]; +}; + +void Tile(float *input_data, float *output_data, TileParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TILE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc new file mode 100644 index 00000000000..30da41e3ff1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/topk.h" + +int DescendCmp(const void *a, const void *b) { + return ((const Node *)b)->element - ((const Node *)a)->element; +} + +int AscendCmp(const void *a, const void *b) { + return ((const Node *)a)->element - ((const Node *)b)->element; +} + +void Topk(float *input_data, float *output_data, float *output_index, TopkParameter *parameter) { + int last_dim_size = parameter->last_dim_size_; + int loop_num = parameter->loop_num_; + int k = parameter->k_; + Node *top_map = parameter->topk_node_list_; + + float *cur_input_data = input_data; + float *cur_output_data = output_data; + float *cur_output_index = output_index; + for (int i = 0; i < loop_num; i++) { + for (int j = 0; j < last_dim_size; j++) { + top_map[j].element = *(cur_input_data + j); + top_map[j].index = j; + } + if (parameter->sorted_) { + qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp); + } else { + qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmp); + } + for (int m = 0; m < k; m++) { + cur_output_data[m] = top_map[m].element; + cur_output_index[m] = top_map[m].index; + } + cur_input_data += last_dim_size; + cur_output_data += k; + cur_output_index += k; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.h b/mindspore/lite/src/runtime/kernel/arm/opclib/topk.h new file mode 100644 index 00000000000..3a038aa5925 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/topk.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct Node { + float element; + float index; +}; + +struct TopkParameter { + OpParameter op_parameter_; + int last_dim_size_; + int loop_num_; + int k_; + bool sorted_; + Node *topk_node_list_; +}; + +void Topk(float *input_data, float *output_data, float *output_index, TopkParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.cc new file mode 100644 index 00000000000..b9984aee068 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/transpose.h" +#include +#include "src/runtime/kernel/arm/opclib/errorcode.h" + +void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; j++) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; j++) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; k++) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; i++) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; j++) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; k++) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; m++) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, + TransposeParameter *transpose_param) { + if (in_data == nullptr || out_data == nullptr) { + return OPCLIB_ERR; + } + int *perm = transpose_param->perm_; + int *strides = transpose_param->strides_; + int *out_strides = transpose_param->out_strides_; + int data_size = transpose_param->data_size_; + int num_axes = transpose_param->num_axes_; + + if (num_axes < 2 || num_axes > 4) { + return OPCLIB_ERR; + } + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return OPCLIB_OK; + } + if (num_axes == 2) { + TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } + return OPCLIB_OK; +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.h b/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.h new file mode 100644 index 00000000000..202ed4f4574 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/transpose.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TRANSPOSE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct TransposeParameter { + OpParameter op_parameter_; + int perm_[8]; + bool conjugate_; + int num_axes_; + int strides_[8]; + int out_strides_[8]; + int data_size_; +}; + +int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape, + TransposeParameter *transpose_param); +void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); +void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); +void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/unique.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/unique.cc new file mode 100644 index 00000000000..1917a050e90 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/unique.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/unique.h" + +int Find(float *array, int len, float target) { + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void Unique(float *input, int input_len, float *output0, int *output0_len, int *output1) { + output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = Find(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/unique.h b/mindspore/lite/src/runtime/kernel/arm/opclib/unique.h new file mode 100644 index 00000000000..65fe60d2292 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/unique.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNIQUE_H +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNIQUE_H + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct UniqueParameter { + OpParameter op_parameter_; +}; + +void Unique(float *input, int input_len, float *output0, int *output0_len, int *output1); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNIQUE_H + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.cc new file mode 100644 index 00000000000..98fdd22f936 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/unstack.h" +#include + +void Unistack(float *input, float **output, UnstackParameter *para) { + for (int j = 0; j < para->num_; j++) { + float *out_addr = output[j]; + int out_offset = 0; + for (int i = 0; i < para->pre_dims_; i++) { + int in_offset = i * para->axis_dim_ * para->after_dims_ + j * para->after_dims_; + (void)memcpy(out_addr + out_offset, input + in_offset, para->after_dims_ * sizeof(float)); + out_offset += para->after_dims_; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.h b/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.h new file mode 100644 index 00000000000..e4fb3577831 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/unstack.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSTACK_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSTACK_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct UnstackParameter { + OpParameter op_parameter_; + int num_; + int axis_; + int pre_dims_; + int axis_dim_; + int after_dims_; +}; + +void Unistack(float *input, float **output, UnstackParameter *para); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_UNSTACK_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/where.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/where.cc new file mode 100644 index 00000000000..bd7f302dcec --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/where.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/where.h" + +void Where(bool *input, float *input1, float *input2, float *output, WhereParameter *where_param_, int task_id) { + for (int i = task_id; i < where_param_->number_; i += where_param_->op_parameter_.thread_num_) { + if (input[where_param_->num_ > 1 ? i : 0] == true) { + output[i] = input1[where_param_->num1_ > 1 ? i : 0]; + } else { + output[i] = input2[where_param_->num2_ > 1 ? i : 0]; + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/where.h b/mindspore/lite/src/runtime/kernel/arm/opclib/where.h new file mode 100644 index 00000000000..f0dda220a4a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/where.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WHERE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WHERE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +struct WhereParameter { + OpParameter op_parameter_; + int num_; + int num1_; + int num2_; + int number_; + int thread_num_; +}; + +void Where(bool *input, float *input1, float *input2, float *output, WhereParameter *where_param_, int task_id); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WHERE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc new file mode 100644 index 00000000000..162b58f7942 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.cc @@ -0,0 +1,1883 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/winograd_transform.h" + +// fp32 conv winograd +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic4 = UP_DIV(in_channel, C4NUM); + int pad_h = conv_param->pad_h_; + int pad_w = conv_param->pad_w_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM; + int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * ic4 * C4NUM; + int dst_x_offset = dst_y_offset + j * C4NUM; + float *src_addr = (float *)(input_data) + src_x_offset; + float *dst_addr = tmp_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + // input transform + int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM; + size_t dst_step = ic4 * C4NUM * TILE_NUM; + float *trans_input_ptr = trans_input + dst_ic4_offset; + input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, ConvParameter *conv_param, + OutputTransformUnitFunc output_trans_func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_unit_block = UP_DIV(output_w, output_unit); + int output_channel = conv_param->output_channel_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = conv_param->input_unit_; + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_unit_block * output_unit); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = + dst_tile_offset + j * C4NUM * output_unit_block * output_unit_block * output_unit * output_unit; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = tmp_out_data + dst_oc4_offset; + output_trans_func(src_ptr, dst_ptr, bias_ptr, C4NUM, output_unit_block * output_unit); + } + out_tile_index++; + } +} + +// fp32 conv3x3 +void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step) { +#ifdef ENABLE_ARM + float32x4_t d00 = vld1q_f32(tmp_data); + float32x4_t d01 = vld1q_f32(tmp_data + 4); + float32x4_t d02 = vld1q_f32(tmp_data + 2 * 4); + float32x4_t d03 = vld1q_f32(tmp_data + 3 * 4); + + float32x4_t d10 = vld1q_f32(tmp_data + 4 * 4); + float32x4_t d11 = vld1q_f32(tmp_data + 5 * 4); + float32x4_t d12 = vld1q_f32(tmp_data + 6 * 4); + float32x4_t d13 = vld1q_f32(tmp_data + 7 * 4); + + float32x4_t d20 = vld1q_f32(tmp_data + 8 * 4); + float32x4_t d21 = vld1q_f32(tmp_data + 9 * 4); + float32x4_t d22 = vld1q_f32(tmp_data + 10 * 4); + float32x4_t d23 = vld1q_f32(tmp_data + 11 * 4); + + float32x4_t d30 = vld1q_f32(tmp_data + 12 * 4); + float32x4_t d31 = vld1q_f32(tmp_data + 13 * 4); + float32x4_t d32 = vld1q_f32(tmp_data + 14 * 4); + float32x4_t d33 = vld1q_f32(tmp_data + 15 * 4); + + float32x4_t t00 = vsubq_f32(d00, d20); + float32x4_t t01 = vsubq_f32(d01, d21); + float32x4_t t02 = vsubq_f32(d02, d22); + float32x4_t t03 = vsubq_f32(d03, d23); + + float32x4_t t10 = vaddq_f32(d10, d20); + float32x4_t t11 = vaddq_f32(d11, d21); + float32x4_t t12 = vaddq_f32(d12, d22); + float32x4_t t13 = vaddq_f32(d13, d23); + + float32x4_t t20 = vsubq_f32(d20, d10); + float32x4_t t21 = vsubq_f32(d21, d11); + float32x4_t t22 = vsubq_f32(d22, d12); + float32x4_t t23 = vsubq_f32(d23, d13); + + float32x4_t t30 = vsubq_f32(d10, d30); + float32x4_t t31 = vsubq_f32(d11, d31); + float32x4_t t32 = vsubq_f32(d12, d32); + float32x4_t t33 = vsubq_f32(d13, d33); + + float32x4_t m00 = vsubq_f32(t00, t02); + float32x4_t m01 = vaddq_f32(t01, t02); + float32x4_t m02 = vsubq_f32(t02, t01); + float32x4_t m03 = vsubq_f32(t01, t03); + + float32x4_t m10 = vsubq_f32(t10, t12); + float32x4_t m11 = vaddq_f32(t11, t12); + float32x4_t m12 = vsubq_f32(t12, t11); + float32x4_t m13 = vsubq_f32(t11, t13); + + float32x4_t m20 = vsubq_f32(t20, t22); + float32x4_t m21 = vaddq_f32(t21, t22); + float32x4_t m22 = vsubq_f32(t22, t21); + float32x4_t m23 = vsubq_f32(t21, t23); + + float32x4_t m30 = vsubq_f32(t30, t32); + float32x4_t m31 = vaddq_f32(t31, t32); + float32x4_t m32 = vsubq_f32(t32, t31); + float32x4_t m33 = vsubq_f32(t31, t33); + + vst1q_f32(trans_input_data, m00); + vst1q_f32(trans_input_data + step, m01); + vst1q_f32(trans_input_data + 2 * step, m02); + vst1q_f32(trans_input_data + 3 * step, m03); + + vst1q_f32(trans_input_data + 4 * step, m10); + vst1q_f32(trans_input_data + 5 * step, m11); + vst1q_f32(trans_input_data + 6 * step, m12); + vst1q_f32(trans_input_data + 7 * step, m13); + + vst1q_f32(trans_input_data + 8 * step, m20); + vst1q_f32(trans_input_data + 9 * step, m21); + vst1q_f32(trans_input_data + 10 * step, m22); + vst1q_f32(trans_input_data + 11 * step, m23); + + vst1q_f32(trans_input_data + 12 * step, m30); + vst1q_f32(trans_input_data + 13 * step, m31); + vst1q_f32(trans_input_data + 14 * step, m32); + vst1q_f32(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = tmp_data + i; + float d00 = local_ptr[0]; + float d01 = (local_ptr + C4NUM)[0]; + float d02 = (local_ptr + 2 * C4NUM)[0]; + float d03 = (local_ptr + 3 * C4NUM)[0]; + + float d10 = (local_ptr + 4 * C4NUM)[0]; + float d11 = (local_ptr + 5 * C4NUM)[0]; + float d12 = (local_ptr + 6 * C4NUM)[0]; + float d13 = (local_ptr + 7 * C4NUM)[0]; + + float d20 = (local_ptr + 8 * C4NUM)[0]; + float d21 = (local_ptr + 9 * C4NUM)[0]; + float d22 = (local_ptr + 10 * C4NUM)[0]; + float d23 = (local_ptr + 11 * C4NUM)[0]; + + float d30 = (local_ptr + 12 * C4NUM)[0]; + float d31 = (local_ptr + 13 * C4NUM)[0]; + float d32 = (local_ptr + 14 * C4NUM)[0]; + float d33 = (local_ptr + 15 * C4NUM)[0]; + + float t00 = d00 - d20; + float t01 = d01 - d21; + float t02 = d02 - d22; + float t03 = d03 - d23; + + float t10 = d10 + d20; + float t11 = d11 + d21; + float t12 = d12 + d22; + float t13 = d13 + d23; + + float t20 = d20 - d10; + float t21 = d21 - d11; + float t22 = d22 - d12; + float t23 = d23 - d13; + + float t30 = d10 - d30; + float t31 = d11 - d31; + float t32 = d12 - d32; + float t33 = d13 - d33; + + float m00 = t00 - t02; + float m01 = t01 + t02; + float m02 = t02 - t01; + float m03 = t01 - t03; + + float m10 = t10 - t12; + float m11 = t11 + t12; + float m12 = t12 - t11; + float m13 = t11 - t13; + + float m20 = t20 - t22; + float m21 = t21 + t22; + float m22 = t22 - t21; + float m23 = t21 - t23; + + float m30 = t30 - t32; + float m31 = t31 + t32; + float m32 = t32 - t31; + float m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + int ic4 = UP_DIV(input_channel, C4NUM); + int input_unit = 4; + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); + + int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = real_y_start; interval < real_y_end; interval++) { + int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM; + int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM; + for (int j = 0; j < (real_x_end - real_x_start); j++) { + int src_x_offset = src_y_offset + j * ic4 * C4NUM; + int dst_x_offset = dst_y_offset + j * C4NUM; + float *src_addr = (float *)(input_data) + src_x_offset; + float *dst_addr = tmp_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f32(dst_addr, vld1q_f32(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + (dst_addr + k)[0] = (src_addr + k)[0]; + } +#endif + } + } + + // input transform + int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM; + size_t dst_step = ic4 * C4NUM * TILE_NUM; + float *trans_input_ptr = trans_input + dst_ic4_offset; + Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step); + } + } +} + +void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, + int kernel_plane) { + int input_unit = 4; + int dst_step = iC4 * C4NUM * C8NUM; + for (int o = 0; o < output_channel; o++) { + int oc8_block_num = o / C8NUM; + int oc8_block_rem = o % C8NUM; + int src_oc_offset = o * iC4 * C4NUM * kernel_plane; + int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * input_unit * input_unit + oc8_block_rem; + for (int i = 0; i < iC4; i++) { + float *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; + float *dst_ic4_ptr = trans_weight + dst_oc_offset + i * C8NUM * C4NUM; +#ifdef ENABLE_ARM + float32x4_t g00 = vld1q_f32(src_ic4_ptr); + float32x4_t g01 = vld1q_f32(src_ic4_ptr + 4); + float32x4_t g02 = vld1q_f32(src_ic4_ptr + 2 * 4); + float32x4_t g10 = vld1q_f32(src_ic4_ptr + 3 * 4); + float32x4_t g11 = vld1q_f32(src_ic4_ptr + 4 * 4); + float32x4_t g12 = vld1q_f32(src_ic4_ptr + 5 * 4); + float32x4_t g20 = vld1q_f32(src_ic4_ptr + 6 * 4); + float32x4_t g21 = vld1q_f32(src_ic4_ptr + 7 * 4); + float32x4_t g22 = vld1q_f32(src_ic4_ptr + 8 * 4); + + float32x4_t dst00 = g00; + float32x4_t dst01 = g01; + float32x4_t dst02 = g02; + + float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5)); + float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5)); + float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5)); + dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5)); + float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5)); + dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5)); + float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5)); + dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5)); + + float32x4_t dst30 = g20; + float32x4_t dst31 = g21; + float32x4_t dst32 = g22; + + float32x4_t m00 = dst00; + float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5)); + m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5)); + float32x4_t m03 = dst02; + + float32x4_t m10 = dst10; + float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5)); + m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5)); + float32x4_t m13 = dst12; + + float32x4_t m20 = dst20; + float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5)); + m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5)); + float32x4_t m23 = dst22; + + float32x4_t m30 = dst30; + float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5)); + m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5)); + float32x4_t m33 = dst32; + + dst_ic4_ptr[0] = m00[0]; + dst_ic4_ptr[8] = m00[1]; + dst_ic4_ptr[16] = m00[2]; + dst_ic4_ptr[24] = m00[3]; + + dst_ic4_ptr[0 + dst_step] = m01[0]; + dst_ic4_ptr[8 + dst_step] = m01[1]; + dst_ic4_ptr[16 + dst_step] = m01[2]; + dst_ic4_ptr[24 + dst_step] = m01[3]; + + dst_ic4_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic4_ptr[8 + 2 * dst_step] = m02[1]; + dst_ic4_ptr[16 + 2 * dst_step] = m02[2]; + dst_ic4_ptr[24 + 2 * dst_step] = m02[3]; + + dst_ic4_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic4_ptr[8 + 3 * dst_step] = m03[1]; + dst_ic4_ptr[16 + 3 * dst_step] = m03[2]; + dst_ic4_ptr[24 + 3 * dst_step] = m03[3]; + + dst_ic4_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic4_ptr[8 + 4 * dst_step] = m10[1]; + dst_ic4_ptr[16 + 4 * dst_step] = m10[2]; + dst_ic4_ptr[24 + 4 * dst_step] = m10[3]; + + dst_ic4_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic4_ptr[8 + 5 * dst_step] = m11[1]; + dst_ic4_ptr[16 + 5 * dst_step] = m11[2]; + dst_ic4_ptr[24 + 5 * dst_step] = m11[3]; + + dst_ic4_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic4_ptr[8 + 6 * dst_step] = m12[1]; + dst_ic4_ptr[16 + 6 * dst_step] = m12[2]; + dst_ic4_ptr[24 + 6 * dst_step] = m12[3]; + + dst_ic4_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic4_ptr[8 + 7 * dst_step] = m13[1]; + dst_ic4_ptr[16 + 7 * dst_step] = m13[2]; + dst_ic4_ptr[24 + 7 * dst_step] = m13[3]; + + dst_ic4_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic4_ptr[8 + 8 * dst_step] = m20[1]; + dst_ic4_ptr[16 + 8 * dst_step] = m20[2]; + dst_ic4_ptr[24 + 8 * dst_step] = m20[3]; + + dst_ic4_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic4_ptr[8 + 9 * dst_step] = m21[1]; + dst_ic4_ptr[16 + 9 * dst_step] = m21[2]; + dst_ic4_ptr[24 + 9 * dst_step] = m21[3]; + + dst_ic4_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic4_ptr[8 + 10 * dst_step] = m22[1]; + dst_ic4_ptr[16 + 10 * dst_step] = m22[2]; + dst_ic4_ptr[24 + 10 * dst_step] = m22[3]; + + dst_ic4_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic4_ptr[8 + 11 * dst_step] = m23[1]; + dst_ic4_ptr[16 + 11 * dst_step] = m23[2]; + dst_ic4_ptr[24 + 11 * dst_step] = m23[3]; + + dst_ic4_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic4_ptr[8 + 12 * dst_step] = m30[1]; + dst_ic4_ptr[16 + 12 * dst_step] = m30[2]; + dst_ic4_ptr[24 + 12 * dst_step] = m30[3]; + + dst_ic4_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic4_ptr[8 + 13 * dst_step] = m31[1]; + dst_ic4_ptr[16 + 13 * dst_step] = m31[2]; + dst_ic4_ptr[24 + 13 * dst_step] = m31[3]; + + dst_ic4_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic4_ptr[8 + 14 * dst_step] = m32[1]; + dst_ic4_ptr[16 + 14 * dst_step] = m32[2]; + dst_ic4_ptr[24 + 14 * dst_step] = m32[3]; + + dst_ic4_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic4_ptr[8 + 15 * dst_step] = m33[1]; + dst_ic4_ptr[16 + 15 * dst_step] = m33[2]; + dst_ic4_ptr[24 + 15 * dst_step] = m33[3]; +#else + for (int j = 0; j < C4NUM; j++) { + float *local_ptr = src_ic4_ptr + j; + float dst00 = local_ptr[0]; + float dst01 = (local_ptr + 4)[0]; + float dst02 = (local_ptr + 8)[0]; + + float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0]; + float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0]; + float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0]; + + float dst30 = (local_ptr + 24)[0]; + float dst31 = (local_ptr + 28)[0]; + float dst32 = (local_ptr + 32)[0]; + + float m00 = dst00; + float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02; + float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02; + float m03 = dst02; + + float m10 = dst10; + float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12; + float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12; + float m13 = dst12; + + float m20 = dst20; + float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22; + float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22; + float m23 = dst22; + + float m30 = dst30; + float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32; + float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32; + float m33 = dst32; + + *(dst_ic4_ptr + j * 8) = m00; + *(dst_ic4_ptr + j * 8 + dst_step) = m01; + *(dst_ic4_ptr + j * 8 + 2 * dst_step) = m02; + *(dst_ic4_ptr + j * 8 + 3 * dst_step) = m03; + + *(dst_ic4_ptr + j * 8 + 4 * dst_step) = m10; + *(dst_ic4_ptr + j * 8 + 5 * dst_step) = m11; + *(dst_ic4_ptr + j * 8 + 6 * dst_step) = m12; + *(dst_ic4_ptr + j * 8 + 7 * dst_step) = m13; + + *(dst_ic4_ptr + j * 8 + 8 * dst_step) = m20; + *(dst_ic4_ptr + j * 8 + 9 * dst_step) = m21; + *(dst_ic4_ptr + j * 8 + 10 * dst_step) = m22; + *(dst_ic4_ptr + j * 8 + 11 * dst_step) = m23; + + *(dst_ic4_ptr + j * 8 + 12 * dst_step) = m30; + *(dst_ic4_ptr + j * 8 + 13 * dst_step) = m31; + *(dst_ic4_ptr + j * 8 + 14 * dst_step) = m32; + *(dst_ic4_ptr + j * 8 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound, + bool w_not_bound, int output_w) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + + float32x4_t s00 = vld1q_f32(gemm_out); + float32x4_t s01 = vld1q_f32(gemm_out + 4); + float32x4_t s02 = vld1q_f32(gemm_out + 8); + float32x4_t s03 = vld1q_f32(gemm_out + 12); + + float32x4_t s10 = vld1q_f32(gemm_out + 16); + float32x4_t s11 = vld1q_f32(gemm_out + 20); + float32x4_t s12 = vld1q_f32(gemm_out + 24); + float32x4_t s13 = vld1q_f32(gemm_out + 28); + + float32x4_t s20 = vld1q_f32(gemm_out + 32); + float32x4_t s21 = vld1q_f32(gemm_out + 36); + float32x4_t s22 = vld1q_f32(gemm_out + 40); + float32x4_t s23 = vld1q_f32(gemm_out + 44); + + float32x4_t s30 = vld1q_f32(gemm_out + 48); + float32x4_t s31 = vld1q_f32(gemm_out + 52); + float32x4_t s32 = vld1q_f32(gemm_out + 56); + float32x4_t s33 = vld1q_f32(gemm_out + 60); + + float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20); + float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21); + float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22); + float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23); + + float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30); + float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31); + float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32); + float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33); + + float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr); + float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr); + float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr); + float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr); + + vst1q_f32(output_data, d00); + if (w_not_bound) { + vst1q_f32(output_data + 4, d01); + } + if (h_not_bound) { + vst1q_f32(output_data + output_w * 4, d10); + if (w_not_bound) { + vst1q_f32(output_data + output_w * 4 + 4, d11); + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const float *local_ptr = gemm_out + i; + const float *bias_ptr = bias_data + i; + + float s00 = local_ptr[0]; + float s01 = (local_ptr + 4)[0]; + float s02 = (local_ptr + 8)[0]; + float s03 = (local_ptr + 12)[0]; + + float s10 = (local_ptr + 16)[0]; + float s11 = (local_ptr + 20)[0]; + float s12 = (local_ptr + 24)[0]; + float s13 = (local_ptr + 28)[0]; + + float s20 = (local_ptr + 32)[0]; + float s21 = (local_ptr + 36)[0]; + float s22 = (local_ptr + 40)[0]; + float s23 = (local_ptr + 44)[0]; + + float s30 = (local_ptr + 48)[0]; + float s31 = (local_ptr + 52)[0]; + float s32 = (local_ptr + 56)[0]; + float s33 = (local_ptr + 60)[0]; + + float t00 = s00 + s10 + s20; + float t01 = s01 + s11 + s21; + float t02 = s02 + s12 + s22; + float t03 = s03 + s13 + s23; + + float t10 = s10 - s20 - s30; + float t11 = s11 - s21 - s31; + float t12 = s12 - s22 - s32; + float t13 = s13 - s23 - s33; + + float d00 = t00 + t01 + t02 + bias_ptr[0]; + float d01 = t01 - t02 - t03 + bias_ptr[0]; + float d10 = t10 + t11 + t12 + bias_ptr[0]; + float d11 = t11 - t12 - t13 + bias_ptr[0]; + + (output_data + i)[0] = d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = d11; + } + } + } +#endif +} + +void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = 4; + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + + // output transform + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Fp32OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w); + } + } +} + +#ifdef ENABLE_FP16 +// for fp16 convolution 3x3 filter/input/output transform F(4,3) +void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { + float16x4_t d00 = vld1_f16(tmp_data); + float16x4_t d01 = vld1_f16(tmp_data + 4); + float16x4_t d02 = vld1_f16(tmp_data + 2 * 4); + float16x4_t d03 = vld1_f16(tmp_data + 3 * 4); + float16x4_t d04 = vld1_f16(tmp_data + 4 * 4); + float16x4_t d05 = vld1_f16(tmp_data + 5 * 4); + + float16x4_t d10 = vld1_f16(tmp_data + 6 * 4); + float16x4_t d11 = vld1_f16(tmp_data + 7 * 4); + float16x4_t d12 = vld1_f16(tmp_data + 8 * 4); + float16x4_t d13 = vld1_f16(tmp_data + 9 * 4); + float16x4_t d14 = vld1_f16(tmp_data + 10 * 4); + float16x4_t d15 = vld1_f16(tmp_data + 11 * 4); + + float16x4_t d20 = vld1_f16(tmp_data + 12 * 4); + float16x4_t d21 = vld1_f16(tmp_data + 13 * 4); + float16x4_t d22 = vld1_f16(tmp_data + 14 * 4); + float16x4_t d23 = vld1_f16(tmp_data + 15 * 4); + float16x4_t d24 = vld1_f16(tmp_data + 16 * 4); + float16x4_t d25 = vld1_f16(tmp_data + 17 * 4); + + float16x4_t d30 = vld1_f16(tmp_data + 18 * 4); + float16x4_t d31 = vld1_f16(tmp_data + 19 * 4); + float16x4_t d32 = vld1_f16(tmp_data + 20 * 4); + float16x4_t d33 = vld1_f16(tmp_data + 21 * 4); + float16x4_t d34 = vld1_f16(tmp_data + 22 * 4); + float16x4_t d35 = vld1_f16(tmp_data + 23 * 4); + + float16x4_t d40 = vld1_f16(tmp_data + 24 * 4); + float16x4_t d41 = vld1_f16(tmp_data + 25 * 4); + float16x4_t d42 = vld1_f16(tmp_data + 26 * 4); + float16x4_t d43 = vld1_f16(tmp_data + 27 * 4); + float16x4_t d44 = vld1_f16(tmp_data + 28 * 4); + float16x4_t d45 = vld1_f16(tmp_data + 29 * 4); + + float16x4_t d50 = vld1_f16(tmp_data + 30 * 4); + float16x4_t d51 = vld1_f16(tmp_data + 31 * 4); + float16x4_t d52 = vld1_f16(tmp_data + 32 * 4); + float16x4_t d53 = vld1_f16(tmp_data + 33 * 4); + float16x4_t d54 = vld1_f16(tmp_data + 34 * 4); + float16x4_t d55 = vld1_f16(tmp_data + 35 * 4); + + float16x4_t t00 = vadd_f16(vsub_f16(vmul_n_f16(d00, 4), vmul_n_f16(d20, 5)), d40); + float16x4_t t01 = vadd_f16(vsub_f16(vmul_n_f16(d01, 4), vmul_n_f16(d21, 5)), d41); + float16x4_t t02 = vadd_f16(vsub_f16(vmul_n_f16(d02, 4), vmul_n_f16(d22, 5)), d42); + float16x4_t t03 = vadd_f16(vsub_f16(vmul_n_f16(d03, 4), vmul_n_f16(d23, 5)), d43); + float16x4_t t04 = vadd_f16(vsub_f16(vmul_n_f16(d04, 4), vmul_n_f16(d24, 5)), d44); + float16x4_t t05 = vadd_f16(vsub_f16(vmul_n_f16(d05, 4), vmul_n_f16(d25, 5)), d45); + + float16x4_t t10 = vadd_f16(vadd_f16(d30, d40), vmul_n_f16(vadd_f16(d10, d20), -4)); + float16x4_t t11 = vadd_f16(vadd_f16(d31, d41), vmul_n_f16(vadd_f16(d11, d21), -4)); + float16x4_t t12 = vadd_f16(vadd_f16(d32, d42), vmul_n_f16(vadd_f16(d12, d22), -4)); + float16x4_t t13 = vadd_f16(vadd_f16(d33, d43), vmul_n_f16(vadd_f16(d13, d23), -4)); + float16x4_t t14 = vadd_f16(vadd_f16(d34, d44), vmul_n_f16(vadd_f16(d14, d24), -4)); + float16x4_t t15 = vadd_f16(vadd_f16(d35, d45), vmul_n_f16(vadd_f16(d15, d25), -4)); + + float16x4_t t20 = vadd_f16(vsub_f16(d40, d30), vmul_n_f16(vsub_f16(d10, d20), 4)); + float16x4_t t21 = vadd_f16(vsub_f16(d41, d31), vmul_n_f16(vsub_f16(d11, d21), 4)); + float16x4_t t22 = vadd_f16(vsub_f16(d42, d32), vmul_n_f16(vsub_f16(d12, d22), 4)); + float16x4_t t23 = vadd_f16(vsub_f16(d43, d33), vmul_n_f16(vsub_f16(d13, d23), 4)); + float16x4_t t24 = vadd_f16(vsub_f16(d44, d34), vmul_n_f16(vsub_f16(d14, d24), 4)); + float16x4_t t25 = vadd_f16(vsub_f16(d45, d35), vmul_n_f16(vsub_f16(d15, d25), 4)); + + float16x4_t t30 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d30, d10), 2)); + float16x4_t t31 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d31, d11), 2)); + float16x4_t t32 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d32, d12), 2)); + float16x4_t t33 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d33, d13), 2)); + float16x4_t t34 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d34, d14), 2)); + float16x4_t t35 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d35, d15), 2)); + + float16x4_t t40 = vadd_f16(vsub_f16(d40, d20), vmul_n_f16(vsub_f16(d10, d30), 2)); + float16x4_t t41 = vadd_f16(vsub_f16(d41, d21), vmul_n_f16(vsub_f16(d11, d31), 2)); + float16x4_t t42 = vadd_f16(vsub_f16(d42, d22), vmul_n_f16(vsub_f16(d12, d32), 2)); + float16x4_t t43 = vadd_f16(vsub_f16(d43, d23), vmul_n_f16(vsub_f16(d13, d33), 2)); + float16x4_t t44 = vadd_f16(vsub_f16(d44, d24), vmul_n_f16(vsub_f16(d14, d34), 2)); + float16x4_t t45 = vadd_f16(vsub_f16(d45, d25), vmul_n_f16(vsub_f16(d15, d35), 2)); + + float16x4_t t50 = vadd_f16(vsub_f16(vmul_n_f16(d10, 4), vmul_n_f16(d30, 5)), d50); + float16x4_t t51 = vadd_f16(vsub_f16(vmul_n_f16(d11, 4), vmul_n_f16(d31, 5)), d51); + float16x4_t t52 = vadd_f16(vsub_f16(vmul_n_f16(d12, 4), vmul_n_f16(d32, 5)), d52); + float16x4_t t53 = vadd_f16(vsub_f16(vmul_n_f16(d13, 4), vmul_n_f16(d33, 5)), d53); + float16x4_t t54 = vadd_f16(vsub_f16(vmul_n_f16(d14, 4), vmul_n_f16(d34, 5)), d54); + float16x4_t t55 = vadd_f16(vsub_f16(vmul_n_f16(d15, 4), vmul_n_f16(d35, 5)), d55); + + float16x4_t m00 = vadd_f16(vsub_f16(vmul_n_f16(t00, 4), vmul_n_f16(t02, 5)), t04); + float16x4_t m01 = vadd_f16(vadd_f16(t03, t04), vmul_n_f16(vadd_f16(t01, t02), -4)); + float16x4_t m02 = vadd_f16(vsub_f16(t04, t03), vmul_n_f16(vsub_f16(t01, t02), 4)); + float16x4_t m03 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t03, t01), 2)); + float16x4_t m04 = vadd_f16(vsub_f16(t04, t02), vmul_n_f16(vsub_f16(t01, t03), 2)); + float16x4_t m05 = vadd_f16(vsub_f16(vmul_n_f16(t01, 4), vmul_n_f16(t03, 5)), t05); + + float16x4_t m10 = vadd_f16(vsub_f16(vmul_n_f16(t10, 4), vmul_n_f16(t12, 5)), t14); + float16x4_t m11 = vadd_f16(vadd_f16(t13, t14), vmul_n_f16(vadd_f16(t11, t12), -4)); + float16x4_t m12 = vadd_f16(vsub_f16(t14, t13), vmul_n_f16(vsub_f16(t11, t12), 4)); + float16x4_t m13 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t13, t11), 2)); + float16x4_t m14 = vadd_f16(vsub_f16(t14, t12), vmul_n_f16(vsub_f16(t11, t13), 2)); + float16x4_t m15 = vadd_f16(vsub_f16(vmul_n_f16(t11, 4), vmul_n_f16(t13, 5)), t15); + + float16x4_t m20 = vadd_f16(vsub_f16(vmul_n_f16(t20, 4), vmul_n_f16(t22, 5)), t24); + float16x4_t m21 = vadd_f16(vadd_f16(t23, t24), vmul_n_f16(vadd_f16(t21, t22), -4)); + float16x4_t m22 = vadd_f16(vsub_f16(t24, t23), vmul_n_f16(vsub_f16(t21, t22), 4)); + float16x4_t m23 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t23, t21), 2)); + float16x4_t m24 = vadd_f16(vsub_f16(t24, t22), vmul_n_f16(vsub_f16(t21, t23), 2)); + float16x4_t m25 = vadd_f16(vsub_f16(vmul_n_f16(t21, 4), vmul_n_f16(t23, 5)), t25); + + float16x4_t m30 = vadd_f16(vsub_f16(vmul_n_f16(t30, 4), vmul_n_f16(t32, 5)), t34); + float16x4_t m31 = vadd_f16(vadd_f16(t33, t34), vmul_n_f16(vadd_f16(t31, t32), -4)); + float16x4_t m32 = vadd_f16(vsub_f16(t34, t33), vmul_n_f16(vsub_f16(t31, t32), 4)); + float16x4_t m33 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t33, t31), 2)); + float16x4_t m34 = vadd_f16(vsub_f16(t34, t32), vmul_n_f16(vsub_f16(t31, t33), 2)); + float16x4_t m35 = vadd_f16(vsub_f16(vmul_n_f16(t31, 4), vmul_n_f16(t33, 5)), t35); + + float16x4_t m40 = vadd_f16(vsub_f16(vmul_n_f16(t40, 4), vmul_n_f16(t42, 5)), t44); + float16x4_t m41 = vadd_f16(vadd_f16(t43, t44), vmul_n_f16(vadd_f16(t41, t42), -4)); + float16x4_t m42 = vadd_f16(vsub_f16(t44, t43), vmul_n_f16(vsub_f16(t41, t42), 4)); + float16x4_t m43 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t43, t41), 2)); + float16x4_t m44 = vadd_f16(vsub_f16(t44, t42), vmul_n_f16(vsub_f16(t41, t43), 2)); + float16x4_t m45 = vadd_f16(vsub_f16(vmul_n_f16(t41, 4), vmul_n_f16(t43, 5)), t45); + + float16x4_t m50 = vadd_f16(vsub_f16(vmul_n_f16(t50, 4), vmul_n_f16(t52, 5)), t54); + float16x4_t m51 = vadd_f16(vadd_f16(t53, t54), vmul_n_f16(vadd_f16(t51, t52), -4)); + float16x4_t m52 = vadd_f16(vsub_f16(t54, t53), vmul_n_f16(vsub_f16(t51, t52), 4)); + float16x4_t m53 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t53, t51), 2)); + float16x4_t m54 = vadd_f16(vsub_f16(t54, t52), vmul_n_f16(vsub_f16(t51, t53), 2)); + float16x4_t m55 = vadd_f16(vsub_f16(vmul_n_f16(t51, 4), vmul_n_f16(t53, 5)), t55); + + vst1_f16(trans_input_data, m00); + vst1_f16(trans_input_data + step, m01); + vst1_f16(trans_input_data + 2 * step, m02); + vst1_f16(trans_input_data + 3 * step, m03); + vst1_f16(trans_input_data + 4 * step, m04); + vst1_f16(trans_input_data + 5 * step, m05); + + vst1_f16(trans_input_data + 6 * step, m10); + vst1_f16(trans_input_data + 7 * step, m11); + vst1_f16(trans_input_data + 8 * step, m12); + vst1_f16(trans_input_data + 9 * step, m13); + vst1_f16(trans_input_data + 10 * step, m14); + vst1_f16(trans_input_data + 11 * step, m15); + + vst1_f16(trans_input_data + 12 * step, m20); + vst1_f16(trans_input_data + 13 * step, m21); + vst1_f16(trans_input_data + 14 * step, m22); + vst1_f16(trans_input_data + 15 * step, m23); + vst1_f16(trans_input_data + 16 * step, m24); + vst1_f16(trans_input_data + 17 * step, m25); + + vst1_f16(trans_input_data + 18 * step, m30); + vst1_f16(trans_input_data + 19 * step, m31); + vst1_f16(trans_input_data + 20 * step, m32); + vst1_f16(trans_input_data + 21 * step, m33); + vst1_f16(trans_input_data + 22 * step, m34); + vst1_f16(trans_input_data + 23 * step, m35); + + vst1_f16(trans_input_data + 24 * step, m40); + vst1_f16(trans_input_data + 25 * step, m41); + vst1_f16(trans_input_data + 26 * step, m42); + vst1_f16(trans_input_data + 27 * step, m43); + vst1_f16(trans_input_data + 28 * step, m44); + vst1_f16(trans_input_data + 29 * step, m45); + + vst1_f16(trans_input_data + 30 * step, m50); + vst1_f16(trans_input_data + 31 * step, m51); + vst1_f16(trans_input_data + 32 * step, m52); + vst1_f16(trans_input_data + 33 * step, m53); + vst1_f16(trans_input_data + 34 * step, m54); + vst1_f16(trans_input_data + 35 * step, m55); +} + +void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int output_unit = 4; + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + int ic4 = UP_DIV(input_channel, C4NUM); + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * output_unit - pad_w; + int origin_y = (x_id / out_w_block) * output_unit - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); + + int src_plane_offset = input_channel * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C4NUM; + for (int ic = 0; ic < ic4; ic++) { + // clear tmp buffer + memset(tmp_data, 0, 6 * 6 * C4NUM * sizeof(float16_t)); + + // get real input block with padding + int src_ic4_offset = src_plane_offset + ic * C4NUM; + for (int interval = real_y_start; interval < real_y_end; interval++) { + int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel; + int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM; + for (int j = 0; j < (real_x_end - real_x_start); j++) { + int src_x_offset = src_y_offset + j * input_channel; + int dst_x_offset = dst_y_offset + j * C4NUM; + float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; + dst_addr[0] = src_addr[0]; + dst_addr[1] = src_addr[1]; + dst_addr[2] = src_addr[2]; + dst_addr[3] = src_addr[3]; + } + } + + // todo + // input transform + int dst_ic4_offset = dst_plane_offset + ic * 16 * C4NUM; + size_t dst_step = ic4 * C4NUM * 16; + float16_t *trans_input_ptr = trans_input + dst_ic4_offset; + Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); + } + } +} + +void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, + int kernel_plane) { + int dst_step = iC4 * C4NUM * 8; + for (int o = 0; o < output_channel; o++) { + int oc8_block_num = o / C8NUM; + int oc8_block_rem = o % C8NUM; + int src_oc_offset = o * iC4 * C4NUM * kernel_plane; + int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; + for (int i = 0; i < iC4; i++) { + const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; + float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; + float16x4_t g00 = vld1_f16(src_ic4_ptr); + float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); + float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); + float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); + float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); + float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); + float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4); + float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); + float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); + + float16x4_t dst00 = vmul_n_f16(g00, 0.25); + float16x4_t dst01 = vmul_n_f16(g01, 0.25); + float16x4_t dst02 = vmul_n_f16(g02, 0.25); + + float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); + float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); + float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); + + float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); + float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); + float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); + + float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), + vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); + float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), + vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); + float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), + vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667))); + + float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), + vmul_n_f16(g10, 0.08333333333333)); + float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), + vmul_n_f16(g11, 0.08333333333333)); + float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), + vmul_n_f16(g12, 0.08333333333333)); + + float16x4_t dst50 = g20; + float16x4_t dst51 = g21; + float16x4_t dst52 = g22; + + float16x4_t m00 = vmul_n_f16(dst00, 0.25); + float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); + float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); + float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), + vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); + float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), + vmul_n_f16(dst01, 0.08333333333333)); + float16x4_t m05 = dst02; + + float16x4_t m10 = vmul_n_f16(dst10, 0.25); + float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); + float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); + float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), + vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); + float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), + vmul_n_f16(dst11, 0.08333333333333)); + float16x4_t m15 = dst12; + + float16x4_t m20 = vmul_n_f16(dst20, 0.25); + float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); + float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); + float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), + vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); + float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), + vmul_n_f16(dst21, 0.08333333333333)); + float16x4_t m25 = dst22; + + float16x4_t m30 = vmul_n_f16(dst30, 0.25); + float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667); + float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); + float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), + vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); + float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), + vmul_n_f16(dst31, 0.08333333333333)); + float16x4_t m35 = dst32; + + float16x4_t m40 = vmul_n_f16(dst40, 0.25); + float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); + float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); + float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), + vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); + float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), + vmul_n_f16(dst41, 0.08333333333333)); + float16x4_t m45 = dst42; + + float16x4_t m50 = vmul_n_f16(dst50, 0.25); + float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); + float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); + float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), + vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); + float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), + vmul_n_f16(dst51, 0.08333333333333)); + float16x4_t m55 = dst52; + + for (int j = 0; j < 4; j++) { + dst_ic4_ptr[j * 8] = m00[j]; + dst_ic4_ptr[j * 8 + dst_step] = m01[j]; + dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; + dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; + dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; + dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; + dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; + dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; + dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; + dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; + dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; + dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; + dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; + dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; + dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; + dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; + dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; + dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; + dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; + dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; + dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; + dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; + dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; + dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; + dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; + dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; + dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; + dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; + dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; + dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; + dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; + dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; + dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; + dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; + dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; + dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; + } + } + } +} + +void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, + int output_w) { + float16x8_t s00 = vld1q_f16(gemm_out); + float16x8_t s01 = vld1q_f16(gemm_out + 8); + float16x8_t s02 = vld1q_f16(gemm_out + 16); + float16x8_t s03 = vld1q_f16(gemm_out + 24); + float16x8_t s04 = vld1q_f16(gemm_out + 32); + float16x8_t s05 = vld1q_f16(gemm_out + 40); + + float16x8_t s10 = vld1q_f16(gemm_out + 48); + float16x8_t s11 = vld1q_f16(gemm_out + 56); + float16x8_t s12 = vld1q_f16(gemm_out + 64); + float16x8_t s13 = vld1q_f16(gemm_out + 72); + float16x8_t s14 = vld1q_f16(gemm_out + 80); + float16x8_t s15 = vld1q_f16(gemm_out + 88); + + float16x8_t s20 = vld1q_f16(gemm_out + 96); + float16x8_t s21 = vld1q_f16(gemm_out + 104); + float16x8_t s22 = vld1q_f16(gemm_out + 112); + float16x8_t s23 = vld1q_f16(gemm_out + 120); + float16x8_t s24 = vld1q_f16(gemm_out + 128); + float16x8_t s25 = vld1q_f16(gemm_out + 136); + + float16x8_t s30 = vld1q_f16(gemm_out + 144); + float16x8_t s31 = vld1q_f16(gemm_out + 152); + float16x8_t s32 = vld1q_f16(gemm_out + 160); + float16x8_t s33 = vld1q_f16(gemm_out + 168); + float16x8_t s34 = vld1q_f16(gemm_out + 176); + float16x8_t s35 = vld1q_f16(gemm_out + 184); + + float16x8_t s40 = vld1q_f16(gemm_out + 192); + float16x8_t s41 = vld1q_f16(gemm_out + 200); + float16x8_t s42 = vld1q_f16(gemm_out + 208); + float16x8_t s43 = vld1q_f16(gemm_out + 216); + float16x8_t s44 = vld1q_f16(gemm_out + 224); + float16x8_t s45 = vld1q_f16(gemm_out + 232); + + float16x8_t s50 = vld1q_f16(gemm_out + 240); + float16x8_t s51 = vld1q_f16(gemm_out + 248); + float16x8_t s52 = vld1q_f16(gemm_out + 256); + float16x8_t s53 = vld1q_f16(gemm_out + 264); + float16x8_t s54 = vld1q_f16(gemm_out + 272); + float16x8_t s55 = vld1q_f16(gemm_out + 280); + + float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); + float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); + float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); + float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); + float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); + float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); + + float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); + float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); + float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); + float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); + float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); + float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); + + float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); + float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); + float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); + float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); + float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); + float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); + + float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); + float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); + float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); + float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53); + float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); + float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); + + float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04); + float16x8_t d01 = vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)); + float16x8_t d02 = vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)); + float16x8_t d03 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05); + + float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14); + float16x8_t d11 = vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)); + float16x8_t d12 = vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)); + float16x8_t d13 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15); + + float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24); + float16x8_t d21 = vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)); + float16x8_t d22 = vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)); + float16x8_t d23 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25); + + float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34); + float16x8_t d31 = vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)); + float16x8_t d32 = vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)); + float16x8_t d33 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35); + + vst1q_f16(output_data, d00); + vst1q_f16(output_data + 8, d01); + vst1q_f16(output_data + 16, d02); + vst1q_f16(output_data + 24, d03); + + vst1q_f16(output_data + output_w * 8, d10); + vst1q_f16(output_data + output_w * 8 + 8, d11); + vst1q_f16(output_data + output_w * 8 + 16, d12); + vst1q_f16(output_data + output_w * 8 + 24, d13); + + vst1q_f16(output_data + 2 * output_w * 8, d20); + vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); + vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); + vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); + + vst1q_f16(output_data + 3 * output_w * 8, d30); + vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); + vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); + vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); +} + +void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc8 = UP_DIV(output_channel, C8NUM); + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc8 * C8NUM * 36; + int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w); + + for (int j = 0; j < oc8; j++) { + int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = out_data + dst_oc8_offset; + + // output transform + Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w); + } + } +} +#endif + +// int8 conv3x3 +void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_w_; + int pad_h = conv_param->pad_h_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.quant_args_[0][0].zp_; + int ic8 = UP_DIV(input_channel, C8NUM); + int input_unit = 4; + + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); + + int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + // copy data from origin input to tmp buffer + for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; + + int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; + for (int j = real_y_start; j < real_y_end; j++) { + const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); + int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); + memcpy(dst, src, (real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); + } + // input transform + int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; + size_t dst_step = ic8 * C8NUM * TILE_NUM; + int16_t *trans_input_ptr = trans_input + dst_ic8_offset; + Conv3x3Uint8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); + } + } +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + auto src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + auto dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + auto local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param) { + int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; + int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; + int quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; + int output_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier = vdupq_n_s32(quant_multiplier); + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + int32x4_t ls = vdupq_n_s32(left_shift); + int32x4_t rs = vdupq_n_s32(right_shift); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (uint8_t)d00[0]; + (output_data + 1)[0] = (uint8_t)d00[1]; + (output_data + 2)[0] = (uint8_t)d00[2]; + (output_data + 3)[0] = (uint8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (uint8_t)d01[0]; + *(output_data + 5) = (uint8_t)d01[1]; + *(output_data + 6) = (uint8_t)d01[2]; + *(output_data + 7) = (uint8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (uint8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (uint8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (uint8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (uint8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (uint8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (uint8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (uint8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (uint8_t)d11[3]; + } + } +#else + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift), quant_multiplier), -right_shift); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } +#endif +} + +void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int input_unit = 4; + + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Uint8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, conv_param); + } + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h new file mode 100644 index 00000000000..d0cc7e1b330 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_transform.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_TRANSFORM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_TRANSFORM_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "src/runtime/kernel/arm/opclib/fp32/conv.h" +#include "src/runtime/kernel/arm/opclib/winograd_utils.h" +#include "src/runtime/kernel/arm/opclib/quantization/fixed_point.h" + +#define OUPUT_UNIT 2 + +// for fp32 winograd input/output transform +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func); + +void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, ConvParameter *conv_param, + OutputTransformUnitFunc output_trans_func); + +// for fp32 convolution 3x3 filter/input/output transform +void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step); + +void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane); + +void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound, + bool w_not_bound, int output_w); + +void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +#ifdef ENABLE_FP16 +// for fp16 convolution 3x3 filter/input/output transform +void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); + +void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); + +void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, + int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); +#endif + +// for int8 convolution 3x3 filter/input/output transform +void Conv3x3Uint8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp); + +void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param); + +void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, ConvParameter *conv_param); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_TRANSFORM_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc new file mode 100644 index 00000000000..7e611a27281 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc @@ -0,0 +1,3804 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/winograd_utils.h" +#include + +#define MIN_UNIT 2 +#define MAX_UNIT 8 + +static OutputTransformUnitFunc outputTransformUnit[] = { + nullptr, // 0 + nullptr, // 1 + OutputTransform8x2Unit, + OutputTransform8x3Unit, + OutputTransform8x4Unit, + OutputTransform8x5Unit, + OutputTransform8x6Unit, + OutputTransform8x7Unit, +}; + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vsubq_f32(src_data_00, vmulq_n_f32(src_data_20, 4)); + float32x4_t t01 = vsubq_f32(src_data_01, vmulq_n_f32(src_data_21, 4)); + float32x4_t t02 = vsubq_f32(src_data_02, vmulq_n_f32(src_data_22, 4)); + float32x4_t t03 = vsubq_f32(src_data_03, vmulq_n_f32(src_data_23, 4)); + + float32x4_t t10 = vaddq_f32(src_data_10, vmulq_n_f32(src_data_20, 2)); + float32x4_t t11 = vaddq_f32(src_data_11, vmulq_n_f32(src_data_21, 2)); + float32x4_t t12 = vaddq_f32(src_data_12, vmulq_n_f32(src_data_22, 2)); + float32x4_t t13 = vaddq_f32(src_data_13, vmulq_n_f32(src_data_23, 2)); + + float32x4_t t20 = vsubq_f32(vmulq_n_f32(src_data_20, 2), src_data_10); + float32x4_t t21 = vsubq_f32(vmulq_n_f32(src_data_21, 2), src_data_11); + float32x4_t t22 = vsubq_f32(vmulq_n_f32(src_data_22, 2), src_data_12); + float32x4_t t23 = vsubq_f32(vmulq_n_f32(src_data_23, 2), src_data_13); + + float32x4_t t30 = vsubq_f32(src_data_30, vmulq_n_f32(src_data_10, 0.25)); + float32x4_t t31 = vsubq_f32(src_data_31, vmulq_n_f32(src_data_11, 0.25)); + float32x4_t t32 = vsubq_f32(src_data_32, vmulq_n_f32(src_data_12, 0.25)); + float32x4_t t33 = vsubq_f32(src_data_33, vmulq_n_f32(src_data_13, 0.25)); + + float32x4_t m00 = vsubq_f32(t00, vmulq_n_f32(t02, 4)); + float32x4_t m01 = vaddq_f32(t01, vmulq_n_f32(t02, 2)); + float32x4_t m02 = vsubq_f32(vmulq_n_f32(t02, 2), t01); + float32x4_t m03 = vsubq_f32(t03, vmulq_n_f32(t01, 0.25)); + + float32x4_t m10 = vsubq_f32(t10, vmulq_n_f32(t12, 4)); + float32x4_t m11 = vaddq_f32(t11, vmulq_n_f32(t12, 2)); + float32x4_t m12 = vsubq_f32(vmulq_n_f32(t12, 2), t11); + float32x4_t m13 = vsubq_f32(t13, vmulq_n_f32(t11, 0.25)); + + float32x4_t m20 = vsubq_f32(t20, vmulq_n_f32(t22, 4)); + float32x4_t m21 = vaddq_f32(t21, vmulq_n_f32(t22, 2)); + float32x4_t m22 = vsubq_f32(vmulq_n_f32(t22, 2), t21); + float32x4_t m23 = vsubq_f32(t23, vmulq_n_f32(t21, 0.25)); + + float32x4_t m30 = vsubq_f32(t30, vmulq_n_f32(t32, 4)); + float32x4_t m31 = vaddq_f32(t31, vmulq_n_f32(t32, 2)); + float32x4_t m32 = vsubq_f32(vmulq_n_f32(t32, 2), t31); + float32x4_t m33 = vsubq_f32(t33, vmulq_n_f32(t31, 0.25)); + + vst1q_f32(dst_data + 0 * dst_step, m00); + vst1q_f32(dst_data + 1 * dst_step, m01); + vst1q_f32(dst_data + 2 * dst_step, m02); + vst1q_f32(dst_data + 3 * dst_step, m03); + vst1q_f32(dst_data + 4 * dst_step, m10); + vst1q_f32(dst_data + 5 * dst_step, m11); + vst1q_f32(dst_data + 6 * dst_step, m12); + vst1q_f32(dst_data + 7 * dst_step, m13); + vst1q_f32(dst_data + 8 * dst_step, m20); + vst1q_f32(dst_data + 9 * dst_step, m21); + vst1q_f32(dst_data + 10 * dst_step, m22); + vst1q_f32(dst_data + 11 * dst_step, m23); + vst1q_f32(dst_data + 12 * dst_step, m30); + vst1q_f32(dst_data + 13 * dst_step, m31); + vst1q_f32(dst_data + 14 * dst_step, m32); + vst1q_f32(dst_data + 15 * dst_step, m33); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 - 4 * src_data_20; + float t01 = src_data_01 - 4 * src_data_21; + float t02 = src_data_02 - 4 * src_data_22; + float t03 = src_data_03 - 4 * src_data_23; + + float t10 = src_data_10 + 2 * src_data_20; + float t11 = src_data_11 + 2 * src_data_21; + float t12 = src_data_12 + 2 * src_data_22; + float t13 = src_data_13 + 2 * src_data_23; + + float t20 = 2 * src_data_20 - src_data_10; + float t21 = 2 * src_data_21 - src_data_11; + float t22 = 2 * src_data_22 - src_data_12; + float t23 = 2 * src_data_23 - src_data_13; + + float t30 = src_data_30 - 0.25f * src_data_10; + float t31 = src_data_31 - 0.25f * src_data_11; + float t32 = src_data_32 - 0.25f * src_data_12; + float t33 = src_data_33 - 0.25f * src_data_13; + + float m00 = t00 - 4 * t02; + float m01 = t01 + 2 * t02; + float m02 = 2 * t02 - t01; + float m03 = t03 - 0.25f * t01; + + float m10 = t10 - 4 * t12; + float m11 = t11 + 2 * t12; + float m12 = 2 * t12 - t11; + float m13 = t13 - 0.25f * t11; + + float m20 = t20 - 4 * t22; + float m21 = t21 + 2 * t22; + float m22 = 2 * t22 - t21; + float m23 = t23 - 0.25f * t21; + + float m30 = t30 - 4 * t32; + float m31 = t31 + 2 * t32; + float m32 = 2 * t32 - t31; + float m33 = t33 - 0.25f * t31; + + (dst_data + i)[0] = m00; + (dst_data + i + dst_step)[0] = m01; + (dst_data + i + 2 * dst_step)[0] = m02; + (dst_data + i + 3 * dst_step)[0] = m03; + + (dst_data + i + 4 * dst_step)[0] = m10; + (dst_data + i + 5 * dst_step)[0] = m11; + (dst_data + i + 6 * dst_step)[0] = m12; + (dst_data + i + 7 * dst_step)[0] = m13; + + (dst_data + i + 8 * dst_step)[0] = m20; + (dst_data + i + 9 * dst_step)[0] = m21; + (dst_data + i + 10 * dst_step)[0] = m22; + (dst_data + i + 11 * dst_step)[0] = m23; + + (dst_data + i + 12 * dst_step)[0] = m30; + (dst_data + i + 13 * dst_step)[0] = m31; + (dst_data + i + 14 * dst_step)[0] = m32; + (dst_data + i + 15 * dst_step)[0] = m33; + } +#endif +} + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t t00 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_00, vmulq_n_f32(src_data_20, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_40, 6.222222222222)), + vmulq_n_f32(src_data_60, 1.7777777777777)); + float32x4_t t01 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_01, vmulq_n_f32(src_data_21, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_41, 6.222222222222)), + vmulq_n_f32(src_data_61, 1.7777777777777)); + float32x4_t t02 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_02, vmulq_n_f32(src_data_22, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_42, 6.222222222222)), + vmulq_n_f32(src_data_62, 1.7777777777777)); + float32x4_t t03 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_03, vmulq_n_f32(src_data_23, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_43, 6.222222222222)), + vmulq_n_f32(src_data_63, 1.7777777777777)); + float32x4_t t04 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_04, vmulq_n_f32(src_data_24, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_44, 6.222222222222)), + vmulq_n_f32(src_data_64, 1.7777777777777)); + float32x4_t t05 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_05, vmulq_n_f32(src_data_25, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_45, 6.222222222222)), + vmulq_n_f32(src_data_65, 1.7777777777777)); + float32x4_t t06 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_06, vmulq_n_f32(src_data_26, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_46, 6.222222222222)), + vmulq_n_f32(src_data_66, 1.7777777777777)); + float32x4_t t07 = vsubq_f32(vaddq_f32(vsubq_f32(src_data_07, vmulq_n_f32(src_data_27, 5.44444444444444444444444445)), + vmulq_n_f32(src_data_47, 6.222222222222)), + vmulq_n_f32(src_data_67, 1.7777777777777)); + + float32x4_t t10 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_10, 1.5), vmulq_n_f32(src_data_20, 3)), + vmulq_n_f32(src_data_30, 2.166666666666666667)), + vmulq_n_f32(src_data_40, 4.333333333333)), + vmulq_n_f32(src_data_50, 0.66666666666)), + vmulq_n_f32(src_data_60, 1.333333333333)); + float32x4_t t11 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_11, 1.5), vmulq_n_f32(src_data_21, 3)), + vmulq_n_f32(src_data_31, 2.166666666666666667)), + vmulq_n_f32(src_data_41, 4.333333333333)), + vmulq_n_f32(src_data_51, 0.66666666666)), + vmulq_n_f32(src_data_61, 1.333333333333)); + float32x4_t t12 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_12, 1.5), vmulq_n_f32(src_data_22, 3)), + vmulq_n_f32(src_data_32, 2.166666666666666667)), + vmulq_n_f32(src_data_42, 4.333333333333)), + vmulq_n_f32(src_data_52, 0.66666666666)), + vmulq_n_f32(src_data_62, 1.333333333333)); + float32x4_t t13 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_13, 1.5), vmulq_n_f32(src_data_23, 3)), + vmulq_n_f32(src_data_33, 2.166666666666666667)), + vmulq_n_f32(src_data_43, 4.333333333333)), + vmulq_n_f32(src_data_53, 0.66666666666)), + vmulq_n_f32(src_data_63, 1.333333333333)); + float32x4_t t14 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_14, 1.5), vmulq_n_f32(src_data_24, 3)), + vmulq_n_f32(src_data_34, 2.166666666666666667)), + vmulq_n_f32(src_data_44, 4.333333333333)), + vmulq_n_f32(src_data_54, 0.66666666666)), + vmulq_n_f32(src_data_64, 1.333333333333)); + float32x4_t t15 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_15, 1.5), vmulq_n_f32(src_data_25, 3)), + vmulq_n_f32(src_data_35, 2.166666666666666667)), + vmulq_n_f32(src_data_45, 4.333333333333)), + vmulq_n_f32(src_data_55, 0.66666666666)), + vmulq_n_f32(src_data_65, 1.333333333333)); + float32x4_t t16 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_16, 1.5), vmulq_n_f32(src_data_26, 3)), + vmulq_n_f32(src_data_36, 2.166666666666666667)), + vmulq_n_f32(src_data_46, 4.333333333333)), + vmulq_n_f32(src_data_56, 0.66666666666)), + vmulq_n_f32(src_data_66, 1.333333333333)); + float32x4_t t17 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_17, 1.5), vmulq_n_f32(src_data_27, 3)), + vmulq_n_f32(src_data_37, 2.166666666666666667)), + vmulq_n_f32(src_data_47, 4.333333333333)), + vmulq_n_f32(src_data_57, 0.66666666666)), + vmulq_n_f32(src_data_67, 1.333333333333)); + + float32x4_t t20 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_10, -1.5), vmulq_n_f32(src_data_20, 3)), + vmulq_n_f32(src_data_30, 2.166666666666666667)), + vmulq_n_f32(src_data_40, 4.333333333333)), + vmulq_n_f32(src_data_50, 0.66666666666)), + vmulq_n_f32(src_data_60, 1.333333333333)); + float32x4_t t21 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_11, -1.5), vmulq_n_f32(src_data_21, 3)), + vmulq_n_f32(src_data_31, 2.166666666666666667)), + vmulq_n_f32(src_data_41, 4.333333333333)), + vmulq_n_f32(src_data_51, 0.66666666666)), + vmulq_n_f32(src_data_61, 1.333333333333)); + float32x4_t t22 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_12, -1.5), vmulq_n_f32(src_data_22, 3)), + vmulq_n_f32(src_data_32, 2.166666666666666667)), + vmulq_n_f32(src_data_42, 4.333333333333)), + vmulq_n_f32(src_data_52, 0.66666666666)), + vmulq_n_f32(src_data_62, 1.333333333333)); + float32x4_t t23 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_13, -1.5), vmulq_n_f32(src_data_23, 3)), + vmulq_n_f32(src_data_33, 2.166666666666666667)), + vmulq_n_f32(src_data_43, 4.333333333333)), + vmulq_n_f32(src_data_53, 0.66666666666)), + vmulq_n_f32(src_data_63, 1.333333333333)); + float32x4_t t24 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_14, -1.5), vmulq_n_f32(src_data_24, 3)), + vmulq_n_f32(src_data_34, 2.166666666666666667)), + vmulq_n_f32(src_data_44, 4.333333333333)), + vmulq_n_f32(src_data_54, 0.66666666666)), + vmulq_n_f32(src_data_64, 1.333333333333)); + float32x4_t t25 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_15, -1.5), vmulq_n_f32(src_data_25, 3)), + vmulq_n_f32(src_data_35, 2.166666666666666667)), + vmulq_n_f32(src_data_45, 4.333333333333)), + vmulq_n_f32(src_data_55, 0.66666666666)), + vmulq_n_f32(src_data_65, 1.333333333333)); + float32x4_t t26 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_16, -1.5), vmulq_n_f32(src_data_26, 3)), + vmulq_n_f32(src_data_36, 2.166666666666666667)), + vmulq_n_f32(src_data_46, 4.333333333333)), + vmulq_n_f32(src_data_56, 0.66666666666)), + vmulq_n_f32(src_data_66, 1.333333333333)); + float32x4_t t27 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_17, -1.5), vmulq_n_f32(src_data_27, 3)), + vmulq_n_f32(src_data_37, 2.166666666666666667)), + vmulq_n_f32(src_data_47, 4.333333333333)), + vmulq_n_f32(src_data_57, 0.66666666666)), + vmulq_n_f32(src_data_67, 1.333333333333)); + + float32x4_t t30 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_30, src_data_40), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_10, src_data_20), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_50, src_data_60), 0.53333333333)); + float32x4_t t31 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_31, src_data_41), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_11, src_data_21), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_51, src_data_61), 0.53333333333)); + float32x4_t t32 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_32, src_data_42), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_12, src_data_22), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_52, src_data_62), 0.53333333333)); + float32x4_t t33 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_33, src_data_43), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_13, src_data_23), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_53, src_data_63), 0.53333333333)); + float32x4_t t34 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_34, src_data_44), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_14, src_data_24), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_54, src_data_64), 0.53333333333)); + float32x4_t t35 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_35, src_data_45), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_15, src_data_25), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_55, src_data_65), 0.53333333333)); + float32x4_t t36 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_36, src_data_46), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_16, src_data_26), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_56, src_data_66), 0.53333333333)); + float32x4_t t37 = vsubq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(src_data_37, src_data_47), 1.3333333333333), + vmulq_n_f32(vaddq_f32(src_data_17, src_data_27), -0.3)), + vmulq_n_f32(vaddq_f32(src_data_57, src_data_67), 0.53333333333)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_40, src_data_30), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_50, src_data_60), 0.53333333333)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_41, src_data_31), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_51, src_data_61), 0.53333333333)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_42, src_data_32), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_52, src_data_62), 0.53333333333)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_43, src_data_33), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_53, src_data_63), 0.53333333333)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_44, src_data_34), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_14, src_data_24), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_54, src_data_64), 0.53333333333)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_45, src_data_35), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_15, src_data_25), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_55, src_data_65), 0.53333333333)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_46, src_data_36), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_16, src_data_26), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_56, src_data_66), 0.53333333333)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(src_data_47, src_data_37), 1.3333333333333), + vmulq_n_f32(vsubq_f32(src_data_17, src_data_27), 0.3)), + vmulq_n_f32(vsubq_f32(src_data_57, src_data_67), 0.53333333333)); + + float32x4_t t50 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_10, 0.03333333), vmulq_n_f32(src_data_20, 0.022222222)), + vmulq_n_f32(src_data_30, 0.1666666666)), + vmulq_n_f32(src_data_40, 0.11111111111)), + vmulq_n_f32(src_data_50, 0.133333333)), + vmulq_n_f32(src_data_60, 0.088888888)); + float32x4_t t51 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_11, 0.03333333), vmulq_n_f32(src_data_21, 0.022222222)), + vmulq_n_f32(src_data_31, 0.1666666666)), + vmulq_n_f32(src_data_41, 0.11111111111)), + vmulq_n_f32(src_data_51, 0.133333333)), + vmulq_n_f32(src_data_61, 0.088888888)); + float32x4_t t52 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_12, 0.03333333), vmulq_n_f32(src_data_22, 0.022222222)), + vmulq_n_f32(src_data_32, 0.1666666666)), + vmulq_n_f32(src_data_42, 0.11111111111)), + vmulq_n_f32(src_data_52, 0.133333333)), + vmulq_n_f32(src_data_62, 0.088888888)); + float32x4_t t53 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_13, 0.03333333), vmulq_n_f32(src_data_23, 0.022222222)), + vmulq_n_f32(src_data_33, 0.1666666666)), + vmulq_n_f32(src_data_43, 0.11111111111)), + vmulq_n_f32(src_data_53, 0.133333333)), + vmulq_n_f32(src_data_63, 0.088888888)); + float32x4_t t54 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_14, 0.03333333), vmulq_n_f32(src_data_24, 0.022222222)), + vmulq_n_f32(src_data_34, 0.1666666666)), + vmulq_n_f32(src_data_44, 0.11111111111)), + vmulq_n_f32(src_data_54, 0.133333333)), + vmulq_n_f32(src_data_64, 0.088888888)); + float32x4_t t55 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_15, 0.03333333), vmulq_n_f32(src_data_25, 0.022222222)), + vmulq_n_f32(src_data_35, 0.1666666666)), + vmulq_n_f32(src_data_45, 0.11111111111)), + vmulq_n_f32(src_data_55, 0.133333333)), + vmulq_n_f32(src_data_65, 0.088888888)); + float32x4_t t56 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_16, 0.03333333), vmulq_n_f32(src_data_26, 0.022222222)), + vmulq_n_f32(src_data_36, 0.1666666666)), + vmulq_n_f32(src_data_46, 0.11111111111)), + vmulq_n_f32(src_data_56, 0.133333333)), + vmulq_n_f32(src_data_66, 0.088888888)); + float32x4_t t57 = vaddq_f32( + vaddq_f32( + vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_17, 0.03333333), vmulq_n_f32(src_data_27, 0.022222222)), + vmulq_n_f32(src_data_37, 0.1666666666)), + vmulq_n_f32(src_data_47, 0.11111111111)), + vmulq_n_f32(src_data_57, 0.133333333)), + vmulq_n_f32(src_data_67, 0.088888888)); + + float32x4_t t60 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_10, -0.03333333), vmulq_n_f32(src_data_20, 0.022222222)), + vmulq_n_f32(src_data_30, 0.1666666666)), + vmulq_n_f32(src_data_40, 0.11111111111)), + vmulq_n_f32(src_data_50, -0.133333333)), + vmulq_n_f32(src_data_60, 0.088888888)); + float32x4_t t61 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_11, -0.03333333), vmulq_n_f32(src_data_21, 0.022222222)), + vmulq_n_f32(src_data_31, 0.1666666666)), + vmulq_n_f32(src_data_41, 0.11111111111)), + vmulq_n_f32(src_data_51, -0.133333333)), + vmulq_n_f32(src_data_61, 0.088888888)); + float32x4_t t62 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_12, -0.03333333), vmulq_n_f32(src_data_22, 0.022222222)), + vmulq_n_f32(src_data_32, 0.1666666666)), + vmulq_n_f32(src_data_42, 0.11111111111)), + vmulq_n_f32(src_data_52, -0.133333333)), + vmulq_n_f32(src_data_62, 0.088888888)); + float32x4_t t63 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_13, -0.03333333), vmulq_n_f32(src_data_23, 0.022222222)), + vmulq_n_f32(src_data_33, 0.1666666666)), + vmulq_n_f32(src_data_43, 0.11111111111)), + vmulq_n_f32(src_data_53, -0.133333333)), + vmulq_n_f32(src_data_63, 0.088888888)); + float32x4_t t64 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_14, -0.03333333), vmulq_n_f32(src_data_24, 0.022222222)), + vmulq_n_f32(src_data_34, 0.1666666666)), + vmulq_n_f32(src_data_44, 0.11111111111)), + vmulq_n_f32(src_data_54, -0.133333333)), + vmulq_n_f32(src_data_64, 0.088888888)); + float32x4_t t65 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_15, -0.03333333), vmulq_n_f32(src_data_25, 0.022222222)), + vmulq_n_f32(src_data_35, 0.1666666666)), + vmulq_n_f32(src_data_45, 0.11111111111)), + vmulq_n_f32(src_data_55, -0.133333333)), + vmulq_n_f32(src_data_65, 0.088888888)); + float32x4_t t66 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_16, -0.03333333), vmulq_n_f32(src_data_26, 0.022222222)), + vmulq_n_f32(src_data_36, 0.1666666666)), + vmulq_n_f32(src_data_46, 0.11111111111)), + vmulq_n_f32(src_data_56, -0.133333333)), + vmulq_n_f32(src_data_66, 0.088888888)); + float32x4_t t67 = vaddq_f32( + vaddq_f32( + vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(src_data_17, -0.03333333), vmulq_n_f32(src_data_27, 0.022222222)), + vmulq_n_f32(src_data_37, 0.1666666666)), + vmulq_n_f32(src_data_47, 0.11111111111)), + vmulq_n_f32(src_data_57, -0.133333333)), + vmulq_n_f32(src_data_67, 0.088888888)); + + float32x4_t t70 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_30, 3.0625), vmulq_n_f32(src_data_10, -0.5625)), + vmulq_n_f32(src_data_50, 3.5)), + src_data_70); + float32x4_t t71 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_31, 3.0625), vmulq_n_f32(src_data_11, -0.5625)), + vmulq_n_f32(src_data_51, 3.5)), + src_data_71); + float32x4_t t72 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_32, 3.0625), vmulq_n_f32(src_data_12, -0.5625)), + vmulq_n_f32(src_data_52, 3.5)), + src_data_72); + float32x4_t t73 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_33, 3.0625), vmulq_n_f32(src_data_13, -0.5625)), + vmulq_n_f32(src_data_53, 3.5)), + src_data_73); + float32x4_t t74 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_34, 3.0625), vmulq_n_f32(src_data_14, -0.5625)), + vmulq_n_f32(src_data_54, 3.5)), + src_data_74); + float32x4_t t75 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_35, 3.0625), vmulq_n_f32(src_data_15, -0.5625)), + vmulq_n_f32(src_data_55, 3.5)), + src_data_75); + float32x4_t t76 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_36, 3.0625), vmulq_n_f32(src_data_16, -0.5625)), + vmulq_n_f32(src_data_56, 3.5)), + src_data_76); + float32x4_t t77 = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src_data_37, 3.0625), vmulq_n_f32(src_data_17, -0.5625)), + vmulq_n_f32(src_data_57, 3.5)), + src_data_77); + + float32x4_t m00 = + vsubq_f32(vaddq_f32(vsubq_f32(t00, vmulq_n_f32(t02, 5.444444444444444)), vmulq_n_f32(t04, 6.22222222222)), + vmulq_n_f32(t06, 1.77777777777777777778)); + float32x4_t m01 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 1.5), vmulq_n_f32(t02, 3)), + vmulq_n_f32(t03, 2.16666666666666667)), + vmulq_n_f32(t04, 4.3333333333)), + vmulq_n_f32(t05, 0.66666666667)), + vmulq_n_f32(t06, 1.333333333333)); + float32x4_t m02 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t01, -1.5), vmulq_n_f32(t02, 3)), + vmulq_n_f32(t03, 2.16666666666666667)), + vmulq_n_f32(t04, 4.3333333333)), + vmulq_n_f32(t05, 0.66666666667)), + vmulq_n_f32(t06, 1.333333333333)); + float32x4_t m03 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t01, t02), -0.3), vmulq_n_f32(vaddq_f32(t03, t04), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t05, t06), -0.533333333333)); + float32x4_t m04 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t03, t04), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t05, t06), 0.533333333333)); + float32x4_t m05 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 0.03333333), vmulq_n_f32(t02, 0.0222222)), + vmulq_n_f32(t03, 0.16666666666666667)), + vmulq_n_f32(t04, 0.11111111111)), + vmulq_n_f32(t05, 0.1333333333)), + vmulq_n_f32(t06, 0.08888888888)); + float32x4_t m06 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t01, -0.03333333), vmulq_n_f32(t02, 0.0222222)), + vmulq_n_f32(t03, 0.16666666666666667)), + vmulq_n_f32(t04, 0.11111111111)), + vmulq_n_f32(t05, 0.1333333333)), + vmulq_n_f32(t06, 0.08888888888)); + float32x4_t m07 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, -0.5625), vmulq_n_f32(t03, 3.0625)), vmulq_n_f32(t05, 3.5)), t07); + + float32x4_t m10 = + vsubq_f32(vaddq_f32(vsubq_f32(t10, vmulq_n_f32(t12, 5.444444444444444)), vmulq_n_f32(t14, 6.22222222222)), + vmulq_n_f32(t16, 1.77777777777777777778)); + float32x4_t m11 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 1.5), vmulq_n_f32(t12, 3)), + vmulq_n_f32(t13, 2.16666666666666667)), + vmulq_n_f32(t14, 4.3333333333)), + vmulq_n_f32(t15, 0.66666666667)), + vmulq_n_f32(t16, 1.333333333333)); + float32x4_t m12 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t11, -1.5), vmulq_n_f32(t12, 3)), + vmulq_n_f32(t13, 2.16666666666666667)), + vmulq_n_f32(t14, 4.3333333333)), + vmulq_n_f32(t15, 0.66666666667)), + vmulq_n_f32(t16, 1.333333333333)); + float32x4_t m13 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t11, t12), -0.3), vmulq_n_f32(vaddq_f32(t13, t14), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t15, t16), -0.533333333333)); + float32x4_t m14 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t13, t14), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t15, t16), 0.533333333333)); + float32x4_t m15 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 0.03333333), vmulq_n_f32(t12, 0.0222222)), + vmulq_n_f32(t13, 0.16666666666666667)), + vmulq_n_f32(t14, 0.11111111111)), + vmulq_n_f32(t15, 0.1333333333)), + vmulq_n_f32(t16, 0.08888888888)); + float32x4_t m16 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t11, -0.03333333), vmulq_n_f32(t12, 0.0222222)), + vmulq_n_f32(t13, 0.16666666666666667)), + vmulq_n_f32(t14, 0.11111111111)), + vmulq_n_f32(t15, 0.1333333333)), + vmulq_n_f32(t16, 0.08888888888)); + float32x4_t m17 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, -0.5625), vmulq_n_f32(t13, 3.0625)), vmulq_n_f32(t15, 3.5)), t17); + + float32x4_t m20 = + vsubq_f32(vaddq_f32(vsubq_f32(t20, vmulq_n_f32(t22, 5.444444444444444)), vmulq_n_f32(t24, 6.22222222222)), + vmulq_n_f32(t26, 1.77777777777777777778)); + float32x4_t m21 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 1.5), vmulq_n_f32(t22, 3)), + vmulq_n_f32(t23, 2.16666666666666667)), + vmulq_n_f32(t24, 4.3333333333)), + vmulq_n_f32(t25, 0.66666666667)), + vmulq_n_f32(t26, 1.333333333333)); + float32x4_t m22 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t21, -1.5), vmulq_n_f32(t22, 3)), + vmulq_n_f32(t23, 2.16666666666666667)), + vmulq_n_f32(t24, 4.3333333333)), + vmulq_n_f32(t25, 0.66666666667)), + vmulq_n_f32(t26, 1.333333333333)); + float32x4_t m23 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t21, t22), -0.3), vmulq_n_f32(vaddq_f32(t23, t24), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t25, t26), -0.533333333333)); + float32x4_t m24 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t23, t24), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t25, t26), 0.533333333333)); + float32x4_t m25 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 0.03333333), vmulq_n_f32(t22, 0.0222222)), + vmulq_n_f32(t23, 0.16666666666666667)), + vmulq_n_f32(t24, 0.11111111111)), + vmulq_n_f32(t25, 0.1333333333)), + vmulq_n_f32(t26, 0.08888888888)); + float32x4_t m26 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t21, -0.03333333), vmulq_n_f32(t22, 0.0222222)), + vmulq_n_f32(t23, 0.16666666666666667)), + vmulq_n_f32(t24, 0.11111111111)), + vmulq_n_f32(t25, 0.1333333333)), + vmulq_n_f32(t26, 0.08888888888)); + float32x4_t m27 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, -0.5625), vmulq_n_f32(t23, 3.0625)), vmulq_n_f32(t25, 3.5)), t27); + + float32x4_t m30 = + vsubq_f32(vaddq_f32(vsubq_f32(t30, vmulq_n_f32(t32, 5.444444444444444)), vmulq_n_f32(t34, 6.22222222222)), + vmulq_n_f32(t36, 1.77777777777777777778)); + float32x4_t m31 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 1.5), vmulq_n_f32(t32, 3)), + vmulq_n_f32(t33, 2.16666666666666667)), + vmulq_n_f32(t34, 4.3333333333)), + vmulq_n_f32(t35, 0.66666666667)), + vmulq_n_f32(t36, 1.333333333333)); + float32x4_t m32 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t31, -1.5), vmulq_n_f32(t32, 3)), + vmulq_n_f32(t33, 2.16666666666666667)), + vmulq_n_f32(t34, 4.3333333333)), + vmulq_n_f32(t35, 0.66666666667)), + vmulq_n_f32(t36, 1.333333333333)); + float32x4_t m33 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t31, t32), -0.3), vmulq_n_f32(vaddq_f32(t33, t34), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t35, t36), -0.533333333333)); + float32x4_t m34 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t33, t34), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t35, t36), 0.533333333333)); + float32x4_t m35 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 0.03333333), vmulq_n_f32(t32, 0.0222222)), + vmulq_n_f32(t33, 0.16666666666666667)), + vmulq_n_f32(t34, 0.11111111111)), + vmulq_n_f32(t35, 0.1333333333)), + vmulq_n_f32(t36, 0.08888888888)); + float32x4_t m36 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t31, -0.03333333), vmulq_n_f32(t32, 0.0222222)), + vmulq_n_f32(t33, 0.16666666666666667)), + vmulq_n_f32(t34, 0.11111111111)), + vmulq_n_f32(t35, 0.1333333333)), + vmulq_n_f32(t36, 0.08888888888)); + float32x4_t m37 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, -0.5625), vmulq_n_f32(t33, 3.0625)), vmulq_n_f32(t35, 3.5)), t37); + + float32x4_t m40 = + vsubq_f32(vaddq_f32(vsubq_f32(t40, vmulq_n_f32(t42, 5.444444444444444)), vmulq_n_f32(t44, 6.22222222222)), + vmulq_n_f32(t46, 1.77777777777777777778)); + float32x4_t m41 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 1.5), vmulq_n_f32(t42, 3)), + vmulq_n_f32(t43, 2.16666666666666667)), + vmulq_n_f32(t44, 4.3333333333)), + vmulq_n_f32(t45, 0.66666666667)), + vmulq_n_f32(t46, 1.333333333333)); + float32x4_t m42 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t41, -1.5), vmulq_n_f32(t42, 3)), + vmulq_n_f32(t43, 2.16666666666666667)), + vmulq_n_f32(t44, 4.3333333333)), + vmulq_n_f32(t45, 0.66666666667)), + vmulq_n_f32(t46, 1.333333333333)); + float32x4_t m43 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t41, t42), -0.3), vmulq_n_f32(vaddq_f32(t43, t44), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t45, t46), -0.533333333333)); + float32x4_t m44 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t43, t44), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t45, t46), 0.533333333333)); + float32x4_t m45 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 0.03333333), vmulq_n_f32(t42, 0.0222222)), + vmulq_n_f32(t43, 0.16666666666666667)), + vmulq_n_f32(t44, 0.11111111111)), + vmulq_n_f32(t45, 0.1333333333)), + vmulq_n_f32(t46, 0.08888888888)); + float32x4_t m46 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t41, -0.03333333), vmulq_n_f32(t42, 0.0222222)), + vmulq_n_f32(t43, 0.16666666666666667)), + vmulq_n_f32(t44, 0.11111111111)), + vmulq_n_f32(t45, 0.1333333333)), + vmulq_n_f32(t46, 0.08888888888)); + float32x4_t m47 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, -0.5625), vmulq_n_f32(t43, 3.0625)), vmulq_n_f32(t45, 3.5)), t47); + + float32x4_t m50 = + vsubq_f32(vaddq_f32(vsubq_f32(t50, vmulq_n_f32(t52, 5.444444444444444)), vmulq_n_f32(t54, 6.22222222222)), + vmulq_n_f32(t56, 1.77777777777777777778)); + float32x4_t m51 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 1.5), vmulq_n_f32(t52, 3)), + vmulq_n_f32(t53, 2.16666666666666667)), + vmulq_n_f32(t54, 4.3333333333)), + vmulq_n_f32(t55, 0.66666666667)), + vmulq_n_f32(t56, 1.333333333333)); + float32x4_t m52 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t51, -1.5), vmulq_n_f32(t52, 3)), + vmulq_n_f32(t53, 2.16666666666666667)), + vmulq_n_f32(t54, 4.3333333333)), + vmulq_n_f32(t55, 0.66666666667)), + vmulq_n_f32(t56, 1.333333333333)); + float32x4_t m53 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t51, t52), -0.3), vmulq_n_f32(vaddq_f32(t53, t54), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t55, t56), -0.533333333333)); + float32x4_t m54 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t53, t54), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t55, t56), 0.533333333333)); + float32x4_t m55 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 0.03333333), vmulq_n_f32(t52, 0.0222222)), + vmulq_n_f32(t53, 0.16666666666666667)), + vmulq_n_f32(t54, 0.11111111111)), + vmulq_n_f32(t55, 0.1333333333)), + vmulq_n_f32(t56, 0.08888888888)); + float32x4_t m56 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t51, -0.03333333), vmulq_n_f32(t52, 0.0222222)), + vmulq_n_f32(t53, 0.16666666666666667)), + vmulq_n_f32(t54, 0.11111111111)), + vmulq_n_f32(t55, 0.1333333333)), + vmulq_n_f32(t56, 0.08888888888)); + float32x4_t m57 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, -0.5625), vmulq_n_f32(t53, 3.0625)), vmulq_n_f32(t55, 3.5)), t57); + + float32x4_t m60 = + vsubq_f32(vaddq_f32(vsubq_f32(t60, vmulq_n_f32(t62, 5.444444444444444)), vmulq_n_f32(t64, 6.22222222222)), + vmulq_n_f32(t66, 1.77777777777777777778)); + float32x4_t m61 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 1.5), vmulq_n_f32(t62, 3)), + vmulq_n_f32(t63, 2.16666666666666667)), + vmulq_n_f32(t64, 4.3333333333)), + vmulq_n_f32(t65, 0.66666666667)), + vmulq_n_f32(t66, 1.333333333333)); + float32x4_t m62 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t61, -1.5), vmulq_n_f32(t62, 3)), + vmulq_n_f32(t63, 2.16666666666666667)), + vmulq_n_f32(t64, 4.3333333333)), + vmulq_n_f32(t65, 0.66666666667)), + vmulq_n_f32(t66, 1.333333333333)); + float32x4_t m63 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t61, t62), -0.3), vmulq_n_f32(vaddq_f32(t63, t64), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t65, t66), -0.533333333333)); + float32x4_t m64 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t63, t64), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t65, t66), 0.533333333333)); + float32x4_t m65 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 0.03333333), vmulq_n_f32(t62, 0.0222222)), + vmulq_n_f32(t63, 0.16666666666666667)), + vmulq_n_f32(t64, 0.11111111111)), + vmulq_n_f32(t65, 0.1333333333)), + vmulq_n_f32(t66, 0.08888888888)); + float32x4_t m66 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t61, -0.03333333), vmulq_n_f32(t62, 0.0222222)), + vmulq_n_f32(t63, 0.16666666666666667)), + vmulq_n_f32(t64, 0.11111111111)), + vmulq_n_f32(t65, 0.1333333333)), + vmulq_n_f32(t66, 0.08888888888)); + float32x4_t m67 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, -0.5625), vmulq_n_f32(t63, 3.0625)), vmulq_n_f32(t65, 3.5)), t67); + + float32x4_t m70 = + vsubq_f32(vaddq_f32(vsubq_f32(t70, vmulq_n_f32(t72, 5.444444444444444)), vmulq_n_f32(t74, 6.22222222222)), + vmulq_n_f32(t76, 1.77777777777777777778)); + float32x4_t m71 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 1.5), vmulq_n_f32(t72, 3)), + vmulq_n_f32(t73, 2.16666666666666667)), + vmulq_n_f32(t74, 4.3333333333)), + vmulq_n_f32(t75, 0.66666666667)), + vmulq_n_f32(t76, 1.333333333333)); + float32x4_t m72 = vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t71, -1.5), vmulq_n_f32(t72, 3)), + vmulq_n_f32(t73, 2.16666666666666667)), + vmulq_n_f32(t74, 4.3333333333)), + vmulq_n_f32(t75, 0.66666666667)), + vmulq_n_f32(t76, 1.333333333333)); + float32x4_t m73 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t71, t72), -0.3), vmulq_n_f32(vaddq_f32(t73, t74), 1.33333333333)), + vmulq_n_f32(vaddq_f32(t75, t76), -0.533333333333)); + float32x4_t m74 = + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t73, t74), 1.33333333333)), + vmulq_n_f32(vsubq_f32(t75, t76), 0.533333333333)); + float32x4_t m75 = + vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 0.03333333), vmulq_n_f32(t72, 0.0222222)), + vmulq_n_f32(t73, 0.16666666666666667)), + vmulq_n_f32(t74, 0.11111111111)), + vmulq_n_f32(t75, 0.1333333333)), + vmulq_n_f32(t76, 0.08888888888)); + float32x4_t m76 = + vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(t71, -0.03333333), vmulq_n_f32(t72, 0.0222222)), + vmulq_n_f32(t73, 0.16666666666666667)), + vmulq_n_f32(t74, 0.11111111111)), + vmulq_n_f32(t75, 0.1333333333)), + vmulq_n_f32(t76, 0.08888888888)); + float32x4_t m77 = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, -0.5625), vmulq_n_f32(t73, 3.0625)), vmulq_n_f32(t75, 3.5)), t77); + + vst1q_f32(dst_data + 0 * dst_step, m00); + vst1q_f32(dst_data + 1 * dst_step, m01); + vst1q_f32(dst_data + 2 * dst_step, m02); + vst1q_f32(dst_data + 3 * dst_step, m03); + vst1q_f32(dst_data + 4 * dst_step, m04); + vst1q_f32(dst_data + 5 * dst_step, m05); + vst1q_f32(dst_data + 6 * dst_step, m06); + vst1q_f32(dst_data + 7 * dst_step, m07); + vst1q_f32(dst_data + 8 * dst_step, m10); + vst1q_f32(dst_data + 9 * dst_step, m11); + vst1q_f32(dst_data + 10 * dst_step, m12); + vst1q_f32(dst_data + 11 * dst_step, m13); + vst1q_f32(dst_data + 12 * dst_step, m14); + vst1q_f32(dst_data + 13 * dst_step, m15); + vst1q_f32(dst_data + 14 * dst_step, m16); + vst1q_f32(dst_data + 15 * dst_step, m17); + vst1q_f32(dst_data + 16 * dst_step, m20); + vst1q_f32(dst_data + 17 * dst_step, m21); + vst1q_f32(dst_data + 18 * dst_step, m22); + vst1q_f32(dst_data + 19 * dst_step, m23); + vst1q_f32(dst_data + 20 * dst_step, m24); + vst1q_f32(dst_data + 21 * dst_step, m25); + vst1q_f32(dst_data + 22 * dst_step, m26); + vst1q_f32(dst_data + 23 * dst_step, m27); + vst1q_f32(dst_data + 24 * dst_step, m30); + vst1q_f32(dst_data + 25 * dst_step, m31); + vst1q_f32(dst_data + 26 * dst_step, m32); + vst1q_f32(dst_data + 27 * dst_step, m33); + vst1q_f32(dst_data + 28 * dst_step, m34); + vst1q_f32(dst_data + 29 * dst_step, m35); + vst1q_f32(dst_data + 30 * dst_step, m36); + vst1q_f32(dst_data + 31 * dst_step, m37); + vst1q_f32(dst_data + 32 * dst_step, m40); + vst1q_f32(dst_data + 33 * dst_step, m41); + vst1q_f32(dst_data + 34 * dst_step, m42); + vst1q_f32(dst_data + 35 * dst_step, m43); + vst1q_f32(dst_data + 36 * dst_step, m44); + vst1q_f32(dst_data + 37 * dst_step, m45); + vst1q_f32(dst_data + 38 * dst_step, m46); + vst1q_f32(dst_data + 39 * dst_step, m47); + vst1q_f32(dst_data + 40 * dst_step, m50); + vst1q_f32(dst_data + 41 * dst_step, m51); + vst1q_f32(dst_data + 42 * dst_step, m52); + vst1q_f32(dst_data + 43 * dst_step, m53); + vst1q_f32(dst_data + 44 * dst_step, m54); + vst1q_f32(dst_data + 45 * dst_step, m55); + vst1q_f32(dst_data + 46 * dst_step, m56); + vst1q_f32(dst_data + 47 * dst_step, m57); + vst1q_f32(dst_data + 48 * dst_step, m60); + vst1q_f32(dst_data + 49 * dst_step, m61); + vst1q_f32(dst_data + 50 * dst_step, m62); + vst1q_f32(dst_data + 51 * dst_step, m63); + vst1q_f32(dst_data + 52 * dst_step, m64); + vst1q_f32(dst_data + 53 * dst_step, m65); + vst1q_f32(dst_data + 54 * dst_step, m66); + vst1q_f32(dst_data + 55 * dst_step, m67); + vst1q_f32(dst_data + 56 * dst_step, m70); + vst1q_f32(dst_data + 57 * dst_step, m71); + vst1q_f32(dst_data + 58 * dst_step, m72); + vst1q_f32(dst_data + 59 * dst_step, m73); + vst1q_f32(dst_data + 60 * dst_step, m74); + vst1q_f32(dst_data + 61 * dst_step, m75); + vst1q_f32(dst_data + 62 * dst_step, m76); + vst1q_f32(dst_data + 63 * dst_step, m77); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float t00 = src_data_00 - 5.444444444444444445125f * src_data_20 + 6.222222222222222222223f * src_data_40 - + 1.77777777777777778f * src_data_60; + float t01 = src_data_01 - 5.444444444444444445125f * src_data_21 + 6.222222222222222222223f * src_data_41 - + 1.77777777777777778f * src_data_61; + float t02 = src_data_02 - 5.444444444444444445125f * src_data_22 + 6.222222222222222222223f * src_data_42 - + 1.77777777777777778f * src_data_62; + float t03 = src_data_03 - 5.444444444444444445125f * src_data_23 + 6.222222222222222222223f * src_data_43 - + 1.77777777777777778f * src_data_63; + float t04 = src_data_04 - 5.444444444444444445125f * src_data_24 + 6.222222222222222222223f * src_data_44 - + 1.77777777777777778f * src_data_64; + float t05 = src_data_05 - 5.444444444444444445125f * src_data_25 + 6.222222222222222222223f * src_data_45 - + 1.77777777777777778f * src_data_65; + float t06 = src_data_06 - 5.444444444444444445125f * src_data_26 + 6.222222222222222222223f * src_data_46 - + 1.77777777777777778f * src_data_66; + float t07 = src_data_07 - 5.444444444444444445125f * src_data_27 + 6.222222222222222222223f * src_data_47 - + 1.77777777777777778f * src_data_67; + + float t10 = 1.5f * src_data_10 + 3.0f * src_data_20 - 2.1666666666666667f * src_data_30 - + 4.333333333333333333f * src_data_40 + 0.66666666666666667f * src_data_50 + + 1.333333333333333f * src_data_60; + float t11 = 1.5f * src_data_11 + 3.0f * src_data_21 - 2.1666666666666667f * src_data_31 - + 4.333333333333333333f * src_data_41 + 0.66666666666666667f * src_data_51 + + 1.333333333333333f * src_data_61; + float t12 = 1.5f * src_data_12 + 3.0f * src_data_22 - 2.1666666666666667f * src_data_32 - + 4.333333333333333333f * src_data_42 + 0.66666666666666667f * src_data_52 + + 1.333333333333333f * src_data_62; + float t13 = 1.5f * src_data_13 + 3.0f * src_data_23 - 2.1666666666666667f * src_data_33 - + 4.333333333333333333f * src_data_43 + 0.66666666666666667f * src_data_53 + + 1.333333333333333f * src_data_63; + float t14 = 1.5f * src_data_14 + 3.0f * src_data_24 - 2.1666666666666667f * src_data_34 - + 4.333333333333333333f * src_data_44 + 0.66666666666666667f * src_data_54 + + 1.333333333333333f * src_data_64; + float t15 = 1.5f * src_data_15 + 3.0f * src_data_25 - 2.1666666666666667f * src_data_35 - + 4.333333333333333333f * src_data_45 + 0.66666666666666667f * src_data_55 + + 1.333333333333333f * src_data_65; + float t16 = 1.5f * src_data_16 + 3.0f * src_data_26 - 2.1666666666666667f * src_data_36 - + 4.333333333333333333f * src_data_46 + 0.66666666666666667f * src_data_56 + + 1.333333333333333f * src_data_66; + float t17 = 1.5f * src_data_17 + 3.0f * src_data_27 - 2.1666666666666667f * src_data_37 - + 4.333333333333333333f * src_data_47 + 0.66666666666666667f * src_data_57 + + 1.333333333333333f * src_data_67; + + float t20 = -1.5f * src_data_10 + 3.0f * src_data_20 + 2.1666666666666667f * src_data_30 - + 4.333333333333333333f * src_data_40 - 0.66666666666666667f * src_data_50 + + 1.333333333333333f * src_data_60; + float t21 = -1.5f * src_data_11 + 3.0f * src_data_21 + 2.1666666666666667f * src_data_31 - + 4.333333333333333333f * src_data_41 - 0.66666666666666667f * src_data_51 + + 1.333333333333333f * src_data_61; + float t22 = -1.5f * src_data_12 + 3.0f * src_data_22 + 2.1666666666666667f * src_data_32 - + 4.333333333333333333f * src_data_42 - 0.66666666666666667f * src_data_52 + + 1.333333333333333f * src_data_62; + float t23 = -1.5f * src_data_13 + 3.0f * src_data_23 + 2.1666666666666667f * src_data_33 - + 4.333333333333333333f * src_data_43 - 0.66666666666666667f * src_data_53 + + 1.333333333333333f * src_data_63; + float t24 = -1.5f * src_data_14 + 3.0f * src_data_24 + 2.1666666666666667f * src_data_34 - + 4.333333333333333333f * src_data_44 - 0.66666666666666667f * src_data_54 + + 1.333333333333333f * src_data_64; + float t25 = -1.5f * src_data_15 + 3.0f * src_data_25 + 2.1666666666666667f * src_data_35 - + 4.333333333333333333f * src_data_45 - 0.66666666666666667f * src_data_55 + + 1.333333333333333f * src_data_65; + float t26 = -1.5f * src_data_16 + 3.0f * src_data_26 + 2.1666666666666667f * src_data_36 - + 4.333333333333333333f * src_data_46 - 0.66666666666666667f * src_data_56 + + 1.333333333333333f * src_data_66; + float t27 = -1.5f * src_data_17 + 3.0f * src_data_27 + 2.1666666666666667f * src_data_37 - + 4.333333333333333333f * src_data_47 - 0.66666666666666667f * src_data_57 + + 1.333333333333333f * src_data_67; + + float t30 = -0.3f * (src_data_10 + src_data_20) + 1.33333333333333f * (src_data_30 + src_data_40) - + 0.53333333333f * (src_data_50 + src_data_60); + float t31 = -0.3f * (src_data_11 + src_data_21) + 1.33333333333333f * (src_data_31 + src_data_41) - + 0.53333333333f * (src_data_51 + src_data_61); + float t32 = -0.3f * (src_data_12 + src_data_22) + 1.33333333333333f * (src_data_32 + src_data_42) - + 0.53333333333f * (src_data_52 + src_data_62); + float t33 = -0.3f * (src_data_13 + src_data_23) + 1.33333333333333f * (src_data_33 + src_data_43) - + 0.53333333333f * (src_data_53 + src_data_63); + float t34 = -0.3f * (src_data_14 + src_data_24) + 1.33333333333333f * (src_data_34 + src_data_44) - + 0.53333333333f * (src_data_54 + src_data_64); + float t35 = -0.3f * (src_data_15 + src_data_25) + 1.33333333333333f * (src_data_35 + src_data_45) - + 0.53333333333f * (src_data_55 + src_data_65); + float t36 = -0.3f * (src_data_16 + src_data_26) + 1.33333333333333f * (src_data_36 + src_data_46) - + 0.53333333333f * (src_data_56 + src_data_66); + float t37 = -0.3f * (src_data_17 + src_data_27) + 1.33333333333333f * (src_data_37 + src_data_47) - + 0.53333333333f * (src_data_57 + src_data_67); + + float t40 = 0.3f * (src_data_10 - src_data_20) + 1.33333333333333f * (src_data_40 - src_data_30) + + 0.53333333333f * (src_data_50 - src_data_60); + float t41 = 0.3f * (src_data_11 - src_data_21) + 1.33333333333333f * (src_data_41 - src_data_31) + + 0.53333333333f * (src_data_51 - src_data_61); + float t42 = 0.3f * (src_data_12 - src_data_22) + 1.33333333333333f * (src_data_42 - src_data_32) + + 0.53333333333f * (src_data_52 - src_data_62); + float t43 = 0.3f * (src_data_13 - src_data_23) + 1.33333333333333f * (src_data_43 - src_data_33) + + 0.53333333333f * (src_data_53 - src_data_63); + float t44 = 0.3f * (src_data_14 - src_data_24) + 1.33333333333333f * (src_data_44 - src_data_34) + + 0.53333333333f * (src_data_54 - src_data_64); + float t45 = 0.3f * (src_data_15 - src_data_25) + 1.33333333333333f * (src_data_45 - src_data_35) + + 0.53333333333f * (src_data_55 - src_data_65); + float t46 = 0.3f * (src_data_16 - src_data_26) + 1.33333333333333f * (src_data_46 - src_data_36) + + 0.53333333333f * (src_data_56 - src_data_66); + float t47 = 0.3f * (src_data_17 - src_data_27) + 1.33333333333333f * (src_data_47 - src_data_37) + + 0.53333333333f * (src_data_57 - src_data_67); + + float t50 = 0.0333333333f * src_data_10 + 0.02222222f * src_data_20 - 0.1666666666f * src_data_30 - + 0.1111111111f * src_data_40 + 0.1333333f * src_data_50 + 0.0888888f * src_data_60; + float t51 = 0.0333333333f * src_data_11 + 0.02222222f * src_data_21 - 0.1666666666f * src_data_31 - + 0.1111111111f * src_data_41 + 0.1333333f * src_data_51 + 0.0888888f * src_data_61; + float t52 = 0.0333333333f * src_data_12 + 0.02222222f * src_data_22 - 0.1666666666f * src_data_32 - + 0.1111111111f * src_data_42 + 0.1333333f * src_data_52 + 0.0888888f * src_data_62; + float t53 = 0.0333333333f * src_data_13 + 0.02222222f * src_data_23 - 0.1666666666f * src_data_33 - + 0.1111111111f * src_data_43 + 0.1333333f * src_data_53 + 0.0888888f * src_data_63; + float t54 = 0.0333333333f * src_data_14 + 0.02222222f * src_data_24 - 0.1666666666f * src_data_34 - + 0.1111111111f * src_data_44 + 0.1333333f * src_data_54 + 0.0888888f * src_data_64; + float t55 = 0.0333333333f * src_data_15 + 0.02222222f * src_data_25 - 0.1666666666f * src_data_35 - + 0.1111111111f * src_data_45 + 0.1333333f * src_data_55 + 0.0888888f * src_data_65; + float t56 = 0.0333333333f * src_data_16 + 0.02222222f * src_data_26 - 0.1666666666f * src_data_36 - + 0.1111111111f * src_data_46 + 0.1333333f * src_data_56 + 0.0888888f * src_data_66; + float t57 = 0.0333333333f * src_data_17 + 0.02222222f * src_data_27 - 0.1666666666f * src_data_37 - + 0.1111111111f * src_data_47 + 0.1333333f * src_data_57 + 0.0888888f * src_data_67; + + float t60 = -0.0333333333f * src_data_10 + 0.02222222f * src_data_20 + 0.1666666666f * src_data_30 - + 0.1111111111f * src_data_40 - 0.1333333f * src_data_50 + 0.0888888f * src_data_60; + float t61 = -0.0333333333f * src_data_11 + 0.02222222f * src_data_21 + 0.1666666666f * src_data_31 - + 0.1111111111f * src_data_41 - 0.1333333f * src_data_51 + 0.0888888f * src_data_61; + float t62 = -0.0333333333f * src_data_12 + 0.02222222f * src_data_22 + 0.1666666666f * src_data_32 - + 0.1111111111f * src_data_42 - 0.1333333f * src_data_52 + 0.0888888f * src_data_62; + float t63 = -0.0333333333f * src_data_13 + 0.02222222f * src_data_23 + 0.1666666666f * src_data_33 - + 0.1111111111f * src_data_43 - 0.1333333f * src_data_53 + 0.0888888f * src_data_63; + float t64 = -0.0333333333f * src_data_14 + 0.02222222f * src_data_24 + 0.1666666666f * src_data_34 - + 0.1111111111f * src_data_44 - 0.1333333f * src_data_54 + 0.0888888f * src_data_64; + float t65 = -0.0333333333f * src_data_15 + 0.02222222f * src_data_25 + 0.1666666666f * src_data_35 - + 0.1111111111f * src_data_45 - 0.1333333f * src_data_55 + 0.0888888f * src_data_65; + float t66 = -0.0333333333f * src_data_16 + 0.02222222f * src_data_26 + 0.1666666666f * src_data_36 - + 0.1111111111f * src_data_46 - 0.1333333f * src_data_56 + 0.0888888f * src_data_66; + float t67 = -0.0333333333f * src_data_17 + 0.02222222f * src_data_27 + 0.1666666666f * src_data_37 - + 0.1111111111f * src_data_47 - 0.1333333f * src_data_57 + 0.0888888f * src_data_67; + + float t70 = -0.5625f * src_data_10 + 3.0625f * src_data_30 - 3.5f * src_data_50 + src_data_70; + float t71 = -0.5625f * src_data_11 + 3.0625f * src_data_31 - 3.5f * src_data_51 + src_data_71; + float t72 = -0.5625f * src_data_12 + 3.0625f * src_data_32 - 3.5f * src_data_52 + src_data_72; + float t73 = -0.5625f * src_data_13 + 3.0625f * src_data_33 - 3.5f * src_data_53 + src_data_73; + float t74 = -0.5625f * src_data_14 + 3.0625f * src_data_34 - 3.5f * src_data_54 + src_data_74; + float t75 = -0.5625f * src_data_15 + 3.0625f * src_data_35 - 3.5f * src_data_55 + src_data_75; + float t76 = -0.5625f * src_data_16 + 3.0625f * src_data_36 - 3.5f * src_data_56 + src_data_76; + float t77 = -0.5625f * src_data_17 + 3.0625f * src_data_37 - 3.5f * src_data_57 + src_data_77; + + float m00 = t00 - 5.444444444444444445125f * t02 + 6.222222222222222222223f * t04 - 1.77777777777777778f * t06; + float m01 = 1.5f * t01 + 3.0f * t02 - 2.1666666666666667f * t03 - 4.333333333333333333f * t04 + + 0.66666666666666667f * t05 + 1.333333333333333f * t06; + float m02 = -1.5f * t01 + 3.0f * t02 + 2.1666666666666667f * t03 - 4.333333333333333333f * t04 - + 0.66666666666666667f * t05 + 1.333333333333333f * t06; + float m03 = -0.3f * (t01 + t02) + 1.33333333333333f * (t03 + t04) - 0.53333333333f * (t05 + t06); + float m04 = 0.3f * (t01 - t02) + 1.33333333333333f * (t04 - t03) + 0.53333333333f * (t05 - t06); + float m05 = 0.0333333333f * t01 + 0.02222222f * t02 - 0.1666666666f * t03 - 0.1111111111f * t04 + 0.1333333f * t05 + + 0.0888888f * t06; + float m06 = -0.0333333333f * t01 + 0.02222222f * t02 + 0.1666666666f * t03 - 0.1111111111f * t04 - + 0.1333333f * t05 + 0.0888888f * t06; + float m07 = -0.5625f * t01 + 3.0625f * t03 - 3.5f * t05 + t07; + + float m10 = t10 - 5.444444444444444445125f * t12 + 6.222222222222222222223f * t14 - 1.77777777777777778f * t16; + float m11 = 1.5f * t11 + 3.0f * t12 - 2.1666666666666667f * t13 - 4.333333333333333333f * t14 + + 0.66666666666666667f * t15 + 1.333333333333333f * t16; + float m12 = -1.5f * t11 + 3.0f * t12 + 2.1666666666666667f * t13 - 4.333333333333333333f * t14 - + 0.66666666666666667f * t15 + 1.333333333333333f * t16; + float m13 = -0.3f * (t11 + t12) + 1.33333333333333f * (t13 + t14) - 0.53333333333f * (t15 + t16); + float m14 = 0.3f * (t11 - t12) + 1.33333333333333f * (t14 - t13) + 0.53333333333f * (t15 - t16); + float m15 = 0.0333333333f * t11 + 0.02222222f * t12 - 0.1666666666f * t13 - 0.1111111111f * t14 + 0.1333333f * t15 + + 0.0888888f * t16; + float m16 = -0.0333333333f * t11 + 0.02222222f * t12 + 0.1666666666f * t13 - 0.1111111111f * t14 - + 0.1333333f * t15 + 0.0888888f * t16; + float m17 = -0.5625f * t11 + 3.0625f * t13 - 3.5f * t15 + t17; + + float m20 = t20 - 5.444444444444444445125f * t22 + 6.222222222222222222223f * t24 - 1.77777777777777778f * t26; + float m21 = 1.5f * t21 + 3.0f * t22 - 2.1666666666666667f * t23 - 4.333333333333333333f * t24 + + 0.66666666666666667f * t25 + 1.333333333333333f * t26; + float m22 = -1.5f * t21 + 3.0f * t22 + 2.1666666666666667f * t23 - 4.333333333333333333f * t24 - + 0.66666666666666667f * t25 + 1.333333333333333f * t26; + float m23 = -0.3f * (t21 + t22) + 1.33333333333333f * (t23 + t24) - 0.53333333333f * (t25 + t26); + float m24 = 0.3f * (t21 - t22) + 1.33333333333333f * (t24 - t23) + 0.53333333333f * (t25 - t26); + float m25 = 0.0333333333f * t21 + 0.02222222f * t22 - 0.1666666666f * t23 - 0.1111111111f * t24 + 0.1333333f * t25 + + 0.0888888f * t26; + float m26 = -0.0333333333f * t21 + 0.02222222f * t22 + 0.1666666666f * t23 - 0.1111111111f * t24 - + 0.1333333f * t25 + 0.0888888f * t26; + float m27 = -0.5625f * t21 + 3.0625f * t23 - 3.5f * t25 + t27; + + float m30 = t30 - 5.444444444444444445125f * t32 + 6.222222222222222222223f * t34 - 1.77777777777777778f * t36; + float m31 = 1.5f * t31 + 3.0f * t32 - 2.1666666666666667f * t33 - 4.333333333333333333f * t34 + + 0.66666666666666667f * t35 + 1.333333333333333f * t36; + float m32 = -1.5f * t31 + 3.0f * t32 + 2.1666666666666667f * t33 - 4.333333333333333333f * t34 - + 0.66666666666666667f * t35 + 1.333333333333333f * t36; + float m33 = -0.3f * (t31 + t32) + 1.33333333333333f * (t33 + t34) - 0.53333333333f * (t35 + t36); + float m34 = 0.3f * (t31 - t32) + 1.33333333333333f * (t34 - t33) + 0.53333333333f * (t35 - t36); + float m35 = 0.0333333333f * t31 + 0.02222222f * t32 - 0.1666666666f * t33 - 0.1111111111f * t34 + 0.1333333f * t35 + + 0.0888888f * t36; + float m36 = -0.0333333333f * t31 + 0.02222222f * t32 + 0.1666666666f * t33 - 0.1111111111f * t34 - + 0.1333333f * t35 + 0.0888888f * t36; + float m37 = -0.5625f * t31 + 3.0625f * t33 - 3.5f * t35 + t37; + + float m40 = t40 - 5.444444444444444445125f * t42 + 6.222222222222222222223f * t44 - 1.77777777777777778f * t46; + float m41 = 1.5f * t41 + 3.0f * t42 - 2.1666666666666667f * t43 - 4.333333333333333333f * t44 + + 0.66666666666666667f * t45 + 1.333333333333333f * t46; + float m42 = -1.5f * t41 + 3.0f * t42 + 2.1666666666666667f * t43 - 4.333333333333333333f * t44 - + 0.66666666666666667f * t45 + 1.333333333333333f * t46; + float m43 = -0.3f * (t41 + t42) + 1.33333333333333f * (t43 + t44) - 0.53333333333f * (t45 + t46); + float m44 = 0.3f * (t41 - t42) + 1.33333333333333f * (t44 - t43) + 0.53333333333f * (t45 - t46); + float m45 = 0.0333333333f * t41 + 0.02222222f * t42 - 0.1666666666f * t43 - 0.1111111111f * t44 + 0.1333333f * t45 + + 0.0888888f * t46; + float m46 = -0.0333333333f * t41 + 0.02222222f * t42 + 0.1666666666f * t43 - 0.1111111111f * t44 - + 0.1333333f * t45 + 0.0888888f * t46; + float m47 = -0.5625f * t41 + 3.0625f * t43 - 3.5f * t45 + t47; + + float m50 = t50 - 5.444444444444444445125f * t52 + 6.222222222222222222223f * t54 - 1.77777777777777778f * t56; + float m51 = 1.5f * t51 + 3.0f * t52 - 2.1666666666666667f * t53 - 4.333333333333333333f * t54 + + 0.66666666666666667f * t55 + 1.333333333333333f * t56; + float m52 = -1.5f * t51 + 3.0f * t52 + 2.1666666666666667f * t53 - 4.333333333333333333f * t54 - + 0.66666666666666667f * t55 + 1.333333333333333f * t56; + float m53 = -0.3f * (t51 + t52) + 1.33333333333333f * (t53 + t54) - 0.53333333333f * (t55 + t56); + float m54 = 0.3f * (t51 - t52) + 1.33333333333333f * (t54 - t53) + 0.53333333333f * (t55 - t56); + float m55 = 0.0333333333f * t51 + 0.02222222f * t52 - 0.1666666666f * t53 - 0.1111111111f * t54 + 0.1333333f * t55 + + 0.0888888f * t56; + float m56 = -0.0333333333f * t51 + 0.02222222f * t52 + 0.1666666666f * t53 - 0.1111111111f * t54 - + 0.1333333f * t55 + 0.0888888f * t56; + float m57 = -0.5625f * t51 + 3.0625f * t53 - 3.5f * t55 + t57; + + float m60 = t60 - 5.444444444444444445125f * t62 + 6.222222222222222222223f * t64 - 1.77777777777777778f * t66; + float m61 = 1.5f * t61 + 3.0f * t62 - 2.1666666666666667f * t63 - 4.333333333333333333f * t64 + + 0.66666666666666667f * t65 + 1.333333333333333f * t66; + float m62 = -1.5f * t61 + 3.0f * t62 + 2.1666666666666667f * t63 - 4.333333333333333333f * t64 - + 0.66666666666666667f * t65 + 1.333333333333333f * t66; + float m63 = -0.3f * (t61 + t62) + 1.33333333333333f * (t63 + t64) - 0.53333333333f * (t65 + t66); + float m64 = 0.3f * (t61 - t62) + 1.33333333333333f * (t64 - t63) + 0.53333333333f * (t65 - t66); + float m65 = 0.0333333333f * t61 + 0.02222222f * t62 - 0.1666666666f * t63 - 0.1111111111f * t64 + 0.1333333f * t65 + + 0.0888888f * t66; + float m66 = -0.0333333333f * t61 + 0.02222222f * t62 + 0.1666666666f * t63 - 0.1111111111f * t64 - + 0.1333333f * t65 + 0.0888888f * t66; + float m67 = -0.5625f * t61 + 3.0625f * t63 - 3.5f * t65 + t67; + + float m70 = t70 - 5.444444444444444445125f * t72 + 6.222222222222222222223f * t74 - 1.77777777777777778f * t76; + float m71 = 1.5f * t71 + 3.0f * t72 - 2.1666666666666667f * t73 - 4.333333333333333333f * t74 + + 0.66666666666666667f * t75 + 1.333333333333333f * t76; + float m72 = -1.5f * t71 + 3.0f * t72 + 2.1666666666666667f * t73 - 4.333333333333333333f * t74 - + 0.66666666666666667f * t75 + 1.333333333333333f * t76; + float m73 = -0.3f * (t71 + t72) + 1.33333333333333f * (t73 + t74) - 0.53333333333f * (t75 + t76); + float m74 = 0.3f * (t71 - t72) + 1.33333333333333f * (t74 - t73) + 0.53333333333f * (t75 - t76); + float m75 = 0.0333333333f * t71 + 0.02222222f * t72 - 0.1666666666f * t73 - 0.1111111111f * t74 + 0.1333333f * t75 + + 0.0888888f * t76; + float m76 = -0.0333333333f * t71 + 0.02222222f * t72 + 0.1666666666f * t73 - 0.1111111111f * t74 - + 0.1333333f * t75 + 0.0888888f * t76; + float m77 = -0.5625f * t71 + 3.0625f * t73 - 3.5f * t75 + t77; + + (dst_data + i)[0] = m00; + (dst_data + i + dst_step)[0] = m01; + (dst_data + i + 2 * dst_step)[0] = m02; + (dst_data + i + 3 * dst_step)[0] = m03; + (dst_data + i + 4 * dst_step)[0] = m04; + (dst_data + i + 5 * dst_step)[0] = m05; + (dst_data + i + 6 * dst_step)[0] = m06; + (dst_data + i + 7 * dst_step)[0] = m07; + + (dst_data + i + 8 * dst_step)[0] = m10; + (dst_data + i + 9 * dst_step)[0] = m11; + (dst_data + i + 10 * dst_step)[0] = m12; + (dst_data + i + 11 * dst_step)[0] = m13; + (dst_data + i + 12 * dst_step)[0] = m14; + (dst_data + i + 13 * dst_step)[0] = m15; + (dst_data + i + 14 * dst_step)[0] = m16; + (dst_data + i + 15 * dst_step)[0] = m17; + + (dst_data + i + 16 * dst_step)[0] = m20; + (dst_data + i + 17 * dst_step)[0] = m21; + (dst_data + i + 18 * dst_step)[0] = m22; + (dst_data + i + 19 * dst_step)[0] = m23; + (dst_data + i + 20 * dst_step)[0] = m24; + (dst_data + i + 21 * dst_step)[0] = m25; + (dst_data + i + 22 * dst_step)[0] = m26; + (dst_data + i + 23 * dst_step)[0] = m27; + + (dst_data + i + 24 * dst_step)[0] = m30; + (dst_data + i + 25 * dst_step)[0] = m31; + (dst_data + i + 26 * dst_step)[0] = m32; + (dst_data + i + 27 * dst_step)[0] = m33; + (dst_data + i + 28 * dst_step)[0] = m34; + (dst_data + i + 29 * dst_step)[0] = m35; + (dst_data + i + 30 * dst_step)[0] = m36; + (dst_data + i + 31 * dst_step)[0] = m37; + + (dst_data + i + 32 * dst_step)[0] = m40; + (dst_data + i + 33 * dst_step)[0] = m41; + (dst_data + i + 34 * dst_step)[0] = m42; + (dst_data + i + 35 * dst_step)[0] = m43; + (dst_data + i + 36 * dst_step)[0] = m44; + (dst_data + i + 37 * dst_step)[0] = m45; + (dst_data + i + 38 * dst_step)[0] = m46; + (dst_data + i + 39 * dst_step)[0] = m47; + + (dst_data + i + 40 * dst_step)[0] = m50; + (dst_data + i + 41 * dst_step)[0] = m51; + (dst_data + i + 42 * dst_step)[0] = m52; + (dst_data + i + 43 * dst_step)[0] = m53; + (dst_data + i + 44 * dst_step)[0] = m54; + (dst_data + i + 45 * dst_step)[0] = m55; + (dst_data + i + 46 * dst_step)[0] = m56; + (dst_data + i + 47 * dst_step)[0] = m57; + + (dst_data + i + 48 * dst_step)[0] = m60; + (dst_data + i + 49 * dst_step)[0] = m61; + (dst_data + i + 50 * dst_step)[0] = m62; + (dst_data + i + 51 * dst_step)[0] = m63; + (dst_data + i + 52 * dst_step)[0] = m64; + (dst_data + i + 53 * dst_step)[0] = m65; + (dst_data + i + 54 * dst_step)[0] = m66; + (dst_data + i + 55 * dst_step)[0] = m67; + + (dst_data + i + 56 * dst_step)[0] = m70; + (dst_data + i + 57 * dst_step)[0] = m71; + (dst_data + i + 58 * dst_step)[0] = m72; + (dst_data + i + 59 * dst_step)[0] = m73; + (dst_data + i + 60 * dst_step)[0] = m74; + (dst_data + i + 61 * dst_step)[0] = m75; + (dst_data + i + 62 * dst_step)[0] = m76; + (dst_data + i + 63 * dst_step)[0] = m77; + } +#endif +} + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vaddq_f32(src_data_00, vaddq_f32(src_data_10, src_data_20)); + float32x4_t t01 = vaddq_f32(src_data_01, vaddq_f32(src_data_11, src_data_21)); + float32x4_t t02 = vaddq_f32(src_data_02, vaddq_f32(src_data_12, src_data_22)); + float32x4_t t03 = vaddq_f32(src_data_03, vaddq_f32(src_data_13, src_data_23)); + + float32x4_t t10 = vsubq_f32(src_data_30, vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.5)); + float32x4_t t11 = vsubq_f32(src_data_31, vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.5)); + float32x4_t t12 = vsubq_f32(src_data_32, vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.5)); + float32x4_t t13 = vsubq_f32(src_data_33, vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.5)); + + float32x4_t m00 = vaddq_f32(vaddq_f32(t00, vaddq_f32(t01, t02)), bias_ptr); + float32x4_t m01 = vaddq_f32(vaddq_f32(t03, vmulq_n_f32(vsubq_f32(t01, t02), 0.5)), bias_ptr); + float32x4_t m10 = vaddq_f32(vaddq_f32(t10, vaddq_f32(t11, t12)), bias_ptr); + float32x4_t m11 = vaddq_f32(vaddq_f32(t13, vmulq_n_f32(vsubq_f32(t11, t12), 0.5)), bias_ptr); + + vst1q_f32(dst_data, m00); + vst1q_f32(dst_data + C4NUM, m01); + vst1q_f32(dst_data + dst_step * C4NUM, m10); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m11); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 + src_data_10 + src_data_20; + float t01 = src_data_01 + src_data_11 + src_data_21; + float t02 = src_data_02 + src_data_12 + src_data_22; + float t03 = src_data_03 + src_data_13 + src_data_23; + + float t10 = 0.5f * (src_data_10 - src_data_20) + src_data_30; + float t11 = 0.5f * (src_data_11 - src_data_21) + src_data_31; + float t12 = 0.5f * (src_data_12 - src_data_22) + src_data_32; + float t13 = 0.5f * (src_data_13 - src_data_23) + src_data_33; + + float m00 = t00 + t01 + t02 + bias_data[i]; + float m01 = 0.5f * (t01 - t02) + t03 + bias_data[i]; + float m10 = t10 + t11 + t12 + bias_data[i]; + float m11 = 0.5f * (t11 - t12) + t13 + bias_data[i]; + + (dst_data + i)[0] = m00; + (dst_data + i + C4NUM)[0] = m01; + (dst_data + i + dst_step * C4NUM)[0] = m10; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11; + } +#endif +} + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t bias_ptr = vld1q_f32(bias_data); + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 15 * src_step); + + float32x4_t t00 = vaddq_f32(src_data_00, vaddq_f32(src_data_10, src_data_20)); + float32x4_t t01 = vaddq_f32(src_data_01, vaddq_f32(src_data_11, src_data_21)); + float32x4_t t02 = vaddq_f32(src_data_02, vaddq_f32(src_data_12, src_data_22)); + float32x4_t t03 = vaddq_f32(src_data_03, vaddq_f32(src_data_13, src_data_23)); + + float32x4_t t10 = vmulq_n_f32(vsubq_f32(src_data_10, src_data_20), 0.5); + float32x4_t t11 = vmulq_n_f32(vsubq_f32(src_data_11, src_data_21), 0.5); + float32x4_t t12 = vmulq_n_f32(vsubq_f32(src_data_12, src_data_22), 0.5); + float32x4_t t13 = vmulq_n_f32(vsubq_f32(src_data_13, src_data_23), 0.5); + + float32x4_t t20 = vaddq_f32(src_data_30, vmulq_n_f32(vaddq_f32(src_data_10, src_data_20), 0.25)); + float32x4_t t21 = vaddq_f32(src_data_31, vmulq_n_f32(vaddq_f32(src_data_11, src_data_21), 0.25)); + float32x4_t t22 = vaddq_f32(src_data_32, vmulq_n_f32(vaddq_f32(src_data_12, src_data_22), 0.25)); + float32x4_t t23 = vaddq_f32(src_data_33, vmulq_n_f32(vaddq_f32(src_data_13, src_data_23), 0.25)); + + float32x4_t m00 = vaddq_f32(vaddq_f32(t00, vaddq_f32(t01, t02)), bias_ptr); + float32x4_t m01 = vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.5), bias_ptr); + float32x4_t m02 = vaddq_f32(vaddq_f32(t03, vmulq_n_f32(vaddq_f32(t01, t02), 0.25)), bias_ptr); + float32x4_t m10 = vaddq_f32(vaddq_f32(t10, vaddq_f32(t11, t12)), bias_ptr); + float32x4_t m11 = vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.5), bias_ptr); + float32x4_t m12 = vaddq_f32(vaddq_f32(t13, vmulq_n_f32(vaddq_f32(t11, t12), 0.25)), bias_ptr); + float32x4_t m20 = vaddq_f32(vaddq_f32(t20, vaddq_f32(t21, t22)), bias_ptr); + float32x4_t m21 = vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.5), bias_ptr); + float32x4_t m22 = vaddq_f32(vaddq_f32(t23, vmulq_n_f32(vaddq_f32(t21, t22), 0.25)), bias_ptr); + + vst1q_f32(dst_data, m00); + vst1q_f32(dst_data + C4NUM, m01); + vst1q_f32(dst_data + 2 * C4NUM, m02); + vst1q_f32(dst_data + dst_step * C4NUM, m10); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m11); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m12); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, m20); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m21); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m22); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_10 = src_data[i + 4 * src_step]; + float src_data_11 = src_data[i + 5 * src_step]; + float src_data_12 = src_data[i + 6 * src_step]; + float src_data_13 = src_data[i + 7 * src_step]; + float src_data_20 = src_data[i + 8 * src_step]; + float src_data_21 = src_data[i + 9 * src_step]; + float src_data_22 = src_data[i + 10 * src_step]; + float src_data_23 = src_data[i + 11 * src_step]; + float src_data_30 = src_data[i + 12 * src_step]; + float src_data_31 = src_data[i + 13 * src_step]; + float src_data_32 = src_data[i + 14 * src_step]; + float src_data_33 = src_data[i + 15 * src_step]; + + float t00 = src_data_00 + src_data_10 + src_data_20; + float t01 = src_data_01 + src_data_11 + src_data_21; + float t02 = src_data_02 + src_data_12 + src_data_22; + float t03 = src_data_03 + src_data_13 + src_data_23; + + float t10 = 0.5f * (src_data_10 - src_data_20); + float t11 = 0.5f * (src_data_11 - src_data_21); + float t12 = 0.5f * (src_data_12 - src_data_22); + float t13 = 0.5f * (src_data_13 - src_data_23); + + float t20 = 0.25f * (src_data_10 + src_data_20) + src_data_30; + float t21 = 0.25f * (src_data_11 + src_data_21) + src_data_31; + float t22 = 0.25f * (src_data_12 + src_data_22) + src_data_32; + float t23 = 0.25f * (src_data_13 + src_data_23) + src_data_33; + + float m00 = t00 + t01 + t02 + bias_data[i]; + float m01 = 0.5f * (t01 - t02) + bias_data[i]; + float m02 = 0.25f * (t01 + t02) + t03 + bias_data[i]; + + float m10 = t10 + t11 + t12 + bias_data[i]; + float m11 = 0.5f * (t11 - t12) + bias_data[i]; + float m12 = 0.25f * (t11 + t12) + t13 + bias_data[i]; + + float m20 = t20 + t21 + t22 + bias_data[i]; + float m21 = 0.5f * (t21 - t22) + bias_data[i]; + float m22 = 0.25f * (t21 + t22) + t23 + bias_data[i]; + + (dst_data + i)[0] = m00; + (dst_data + i + C4NUM)[0] = m01; + (dst_data + i + 2 * C4NUM)[0] = m02; + + (dst_data + i + dst_step * C4NUM)[0] = m10; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22; + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21 + src_data_70; + float t11 = 0.5f * d02 + d12 + 1.5f * d22 + src_data_71; + float t12 = 0.5f * d03 + d13 + 1.5f * d23 + src_data_72; + float t13 = 0.5f * d04 + d14 + 1.5f * d24 + src_data_73; + float t14 = 0.5f * d05 + d15 + 1.5f * d25 + src_data_74; + float t15 = 0.5f * d06 + d16 + 1.5f * d26 + src_data_75; + float t16 = 0.5f * d07 + d17 + 1.5f * d27 + src_data_76; + float t17 = 0.5f * d08 + d18 + 1.5f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + + float s21 = t03 - t04; + float s22 = t13 - t14; + + float s31 = t05 - t06; + float s32 = t15 - t16; + + float s41 = t01 + t02; + float s42 = t11 + t12; + + float s51 = t03 + t04; + float s52 = t13 + t14; + + float s61 = t05 + t06; + float s62 = t15 + t16; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32 + t17; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51 + src_data_70; + float t21 = 0.25f * d32 + d42 + 2.25f * d52 + src_data_71; + float t22 = 0.25f * d33 + d43 + 2.25f * d53 + src_data_72; + float t23 = 0.25f * d34 + d44 + 2.25f * d54 + src_data_73; + float t24 = 0.25f * d35 + d45 + 2.25f * d55 + src_data_74; + float t25 = 0.25f * d36 + d46 + 2.25f * d56 + src_data_75; + float t26 = 0.25f * d37 + d47 + 2.25f * d57 + src_data_76; + float t27 = 0.25f * d38 + d48 + 2.25f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63 + t27; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21 + src_data_70; + float t31 = 0.125f * d02 + d12 + 3.375f * d22 + src_data_71; + float t32 = 0.125f * d03 + d13 + 3.375f * d23 + src_data_72; + float t33 = 0.125f * d04 + d14 + 3.375f * d24 + src_data_73; + float t34 = 0.125f * d05 + d15 + 3.375f * d25 + src_data_74; + float t35 = 0.125f * d06 + d16 + 3.375f * d26 + src_data_75; + float t36 = 0.125f * d07 + d17 + 3.375f * d27 + src_data_76; + float t37 = 0.125f * d08 + d18 + 3.375f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34 + t37; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + } +#endif +} + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51 + src_data_70; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52 + src_data_71; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53 + src_data_72; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54 + src_data_73; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55 + src_data_74; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56 + src_data_75; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57 + src_data_76; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65 + t47; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + } +#endif +} + +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)); + + float32x4_t t50 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.03125), d11), vmulq_n_f32(d21, 7.59375)), src_data_70); + float32x4_t t51 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.03125), d12), vmulq_n_f32(d22, 7.59375)), src_data_71); + float32x4_t t52 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.03125), d13), vmulq_n_f32(d23, 7.59375)), src_data_72); + float32x4_t t53 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.03125), d14), vmulq_n_f32(d24, 7.59375)), src_data_73); + float32x4_t t54 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.03125), d15), vmulq_n_f32(d25, 7.59375)), src_data_74); + float32x4_t t55 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.03125), d16), vmulq_n_f32(d26, 7.59375)), src_data_75); + float32x4_t t56 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.03125), d17), vmulq_n_f32(d27, 7.59375)), src_data_76); + float32x4_t t57 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.03125), d18), vmulq_n_f32(d28, 7.59375)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); + float32x4_t s16 = vsubq_f32(t51, t52); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); + float32x4_t s26 = vsubq_f32(t53, t54); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); + float32x4_t s36 = vsubq_f32(t55, t56); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); + float32x4_t s46 = vaddq_f32(t51, t52); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); + float32x4_t s56 = vaddq_f32(t53, t54); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); + float32x4_t s66 = vaddq_f32(t55, t56); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)); + float32x4_t m05 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.03125), s21), vmulq_n_f32(s31, 7.59375)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)); + float32x4_t m15 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.03125), s22), vmulq_n_f32(s32, 7.59375)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)); + float32x4_t m25 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.03125), s23), vmulq_n_f32(s33, 7.59375)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)); + float32x4_t m35 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.03125), s24), vmulq_n_f32(s34, 7.59375)), t37); + + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)); + float32x4_t m45 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.03125), s25), vmulq_n_f32(s35, 7.59375)), t47); + + float32x4_t m50 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t50, t51), t52), t53), t54), t55), t56); + float32x4_t m51 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.5), s26), vmulq_n_f32(s36, 1.5)); + float32x4_t m52 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.25), s56), vmulq_n_f32(s66, 2.25)); + float32x4_t m53 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.125), s26), vmulq_n_f32(s36, 3.375)); + float32x4_t m54 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.0625), s56), vmulq_n_f32(s66, 5.0625)); + float32x4_t m55 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.03125), s26), vmulq_n_f32(s36, 7.59375)), t57); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); + vst1q_f32(dst_data + 5 * C4NUM, vaddq_f32(m05, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m15, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m25, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m35, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m45, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM, vaddq_f32(m50, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + C4NUM, vaddq_f32(m51, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m52, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m53, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m54, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m55, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58; + + float t50 = 0.03125f * d01 + d11 + 7.59375f * d21 + src_data_70; + float t51 = 0.03125f * d02 + d12 + 7.59375f * d22 + src_data_71; + float t52 = 0.03125f * d03 + d13 + 7.59375f * d23 + src_data_72; + float t53 = 0.03125f * d04 + d14 + 7.59375f * d24 + src_data_73; + float t54 = 0.03125f * d05 + d15 + 7.59375f * d25 + src_data_74; + float t55 = 0.03125f * d06 + d16 + 7.59375f * d26 + src_data_75; + float t56 = 0.03125f * d07 + d17 + 7.59375f * d27 + src_data_76; + float t57 = 0.03125f * d08 + d18 + 7.59375f * d28 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + float s16 = t51 - t52; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + float s26 = t53 - t54; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + float s36 = t55 - t56; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + float s46 = t51 + t52; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + float s56 = t53 + t54; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + float s66 = t55 + t56; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61; + float m05 = 0.03125f * s11 + s21 + 7.59375f * s31 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62; + float m15 = 0.03125f * s12 + s22 + 7.59375f * s32 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63; + float m25 = 0.03125f * s13 + s23 + 7.59375f * s33 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64; + float m35 = 0.03125f * s14 + s24 + 7.59375f * s34 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65; + float m45 = 0.03125f * s15 + s25 + 7.59375f * s35 + t47; + + float m50 = t50 + t51 + t52 + t53 + t54 + t55 + t56; + float m51 = 0.5f * s16 + s26 + 1.5f * s36; + float m52 = 0.25f * s46 + s56 + 2.25f * s66; + float m53 = 0.125f * s16 + s26 + 3.375f * s36; + float m54 = 0.0625f * s46 + s56 + 5.0625f * s66; + float m55 = 0.03125f * s16 + s26 + 7.59375f * s36 + t57; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + (dst_data + i + 5 * C4NUM)[0] = m05 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 5 * C4NUM)[0] = m15 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 5 * C4NUM)[0] = m25 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 5 * C4NUM)[0] = m35 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 5 * C4NUM)[0] = m45 + bias_data[i]; + + (dst_data + i + 5 * dst_step * C4NUM)[0] = m50 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + C4NUM)[0] = m51 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 2 * C4NUM)[0] = m52 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 3 * C4NUM)[0] = m53 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 4 * C4NUM)[0] = m54 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 5 * C4NUM)[0] = m55 + bias_data[i]; + } +#endif +} + +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step) { +#ifdef ENABLE_ARM + float32x4_t src_data_00 = vld1q_f32(src_data + 0 * src_step); + float32x4_t src_data_01 = vld1q_f32(src_data + 1 * src_step); + float32x4_t src_data_02 = vld1q_f32(src_data + 2 * src_step); + float32x4_t src_data_03 = vld1q_f32(src_data + 3 * src_step); + float32x4_t src_data_04 = vld1q_f32(src_data + 4 * src_step); + float32x4_t src_data_05 = vld1q_f32(src_data + 5 * src_step); + float32x4_t src_data_06 = vld1q_f32(src_data + 6 * src_step); + float32x4_t src_data_07 = vld1q_f32(src_data + 7 * src_step); + float32x4_t src_data_10 = vld1q_f32(src_data + 8 * src_step); + float32x4_t src_data_11 = vld1q_f32(src_data + 9 * src_step); + float32x4_t src_data_12 = vld1q_f32(src_data + 10 * src_step); + float32x4_t src_data_13 = vld1q_f32(src_data + 11 * src_step); + float32x4_t src_data_14 = vld1q_f32(src_data + 12 * src_step); + float32x4_t src_data_15 = vld1q_f32(src_data + 13 * src_step); + float32x4_t src_data_16 = vld1q_f32(src_data + 14 * src_step); + float32x4_t src_data_17 = vld1q_f32(src_data + 15 * src_step); + float32x4_t src_data_20 = vld1q_f32(src_data + 16 * src_step); + float32x4_t src_data_21 = vld1q_f32(src_data + 17 * src_step); + float32x4_t src_data_22 = vld1q_f32(src_data + 18 * src_step); + float32x4_t src_data_23 = vld1q_f32(src_data + 19 * src_step); + float32x4_t src_data_24 = vld1q_f32(src_data + 20 * src_step); + float32x4_t src_data_25 = vld1q_f32(src_data + 21 * src_step); + float32x4_t src_data_26 = vld1q_f32(src_data + 22 * src_step); + float32x4_t src_data_27 = vld1q_f32(src_data + 23 * src_step); + float32x4_t src_data_30 = vld1q_f32(src_data + 24 * src_step); + float32x4_t src_data_31 = vld1q_f32(src_data + 25 * src_step); + float32x4_t src_data_32 = vld1q_f32(src_data + 26 * src_step); + float32x4_t src_data_33 = vld1q_f32(src_data + 27 * src_step); + float32x4_t src_data_34 = vld1q_f32(src_data + 28 * src_step); + float32x4_t src_data_35 = vld1q_f32(src_data + 29 * src_step); + float32x4_t src_data_36 = vld1q_f32(src_data + 30 * src_step); + float32x4_t src_data_37 = vld1q_f32(src_data + 31 * src_step); + float32x4_t src_data_40 = vld1q_f32(src_data + 32 * src_step); + float32x4_t src_data_41 = vld1q_f32(src_data + 33 * src_step); + float32x4_t src_data_42 = vld1q_f32(src_data + 34 * src_step); + float32x4_t src_data_43 = vld1q_f32(src_data + 35 * src_step); + float32x4_t src_data_44 = vld1q_f32(src_data + 36 * src_step); + float32x4_t src_data_45 = vld1q_f32(src_data + 37 * src_step); + float32x4_t src_data_46 = vld1q_f32(src_data + 38 * src_step); + float32x4_t src_data_47 = vld1q_f32(src_data + 39 * src_step); + float32x4_t src_data_50 = vld1q_f32(src_data + 40 * src_step); + float32x4_t src_data_51 = vld1q_f32(src_data + 41 * src_step); + float32x4_t src_data_52 = vld1q_f32(src_data + 42 * src_step); + float32x4_t src_data_53 = vld1q_f32(src_data + 43 * src_step); + float32x4_t src_data_54 = vld1q_f32(src_data + 44 * src_step); + float32x4_t src_data_55 = vld1q_f32(src_data + 45 * src_step); + float32x4_t src_data_56 = vld1q_f32(src_data + 46 * src_step); + float32x4_t src_data_57 = vld1q_f32(src_data + 47 * src_step); + float32x4_t src_data_60 = vld1q_f32(src_data + 48 * src_step); + float32x4_t src_data_61 = vld1q_f32(src_data + 49 * src_step); + float32x4_t src_data_62 = vld1q_f32(src_data + 50 * src_step); + float32x4_t src_data_63 = vld1q_f32(src_data + 51 * src_step); + float32x4_t src_data_64 = vld1q_f32(src_data + 52 * src_step); + float32x4_t src_data_65 = vld1q_f32(src_data + 53 * src_step); + float32x4_t src_data_66 = vld1q_f32(src_data + 54 * src_step); + float32x4_t src_data_67 = vld1q_f32(src_data + 55 * src_step); + float32x4_t src_data_70 = vld1q_f32(src_data + 56 * src_step); + float32x4_t src_data_71 = vld1q_f32(src_data + 57 * src_step); + float32x4_t src_data_72 = vld1q_f32(src_data + 58 * src_step); + float32x4_t src_data_73 = vld1q_f32(src_data + 59 * src_step); + float32x4_t src_data_74 = vld1q_f32(src_data + 60 * src_step); + float32x4_t src_data_75 = vld1q_f32(src_data + 61 * src_step); + float32x4_t src_data_76 = vld1q_f32(src_data + 62 * src_step); + float32x4_t src_data_77 = vld1q_f32(src_data + 63 * src_step); + + float32x4_t d01 = vsubq_f32(src_data_10, src_data_20); + float32x4_t d02 = vsubq_f32(src_data_11, src_data_21); + float32x4_t d03 = vsubq_f32(src_data_12, src_data_22); + float32x4_t d04 = vsubq_f32(src_data_13, src_data_23); + float32x4_t d05 = vsubq_f32(src_data_14, src_data_24); + float32x4_t d06 = vsubq_f32(src_data_15, src_data_25); + float32x4_t d07 = vsubq_f32(src_data_16, src_data_26); + float32x4_t d08 = vsubq_f32(src_data_17, src_data_27); + + float32x4_t d11 = vsubq_f32(src_data_30, src_data_40); + float32x4_t d12 = vsubq_f32(src_data_31, src_data_41); + float32x4_t d13 = vsubq_f32(src_data_32, src_data_42); + float32x4_t d14 = vsubq_f32(src_data_33, src_data_43); + float32x4_t d15 = vsubq_f32(src_data_34, src_data_44); + float32x4_t d16 = vsubq_f32(src_data_35, src_data_45); + float32x4_t d17 = vsubq_f32(src_data_36, src_data_46); + float32x4_t d18 = vsubq_f32(src_data_37, src_data_47); + + float32x4_t d21 = vsubq_f32(src_data_50, src_data_60); + float32x4_t d22 = vsubq_f32(src_data_51, src_data_61); + float32x4_t d23 = vsubq_f32(src_data_52, src_data_62); + float32x4_t d24 = vsubq_f32(src_data_53, src_data_63); + float32x4_t d25 = vsubq_f32(src_data_54, src_data_64); + float32x4_t d26 = vsubq_f32(src_data_55, src_data_65); + float32x4_t d27 = vsubq_f32(src_data_56, src_data_66); + float32x4_t d28 = vsubq_f32(src_data_57, src_data_67); + + float32x4_t d31 = vaddq_f32(src_data_10, src_data_20); + float32x4_t d32 = vaddq_f32(src_data_11, src_data_21); + float32x4_t d33 = vaddq_f32(src_data_12, src_data_22); + float32x4_t d34 = vaddq_f32(src_data_13, src_data_23); + float32x4_t d35 = vaddq_f32(src_data_14, src_data_24); + float32x4_t d36 = vaddq_f32(src_data_15, src_data_25); + float32x4_t d37 = vaddq_f32(src_data_16, src_data_26); + float32x4_t d38 = vaddq_f32(src_data_17, src_data_27); + + float32x4_t d41 = vaddq_f32(src_data_30, src_data_40); + float32x4_t d42 = vaddq_f32(src_data_31, src_data_41); + float32x4_t d43 = vaddq_f32(src_data_32, src_data_42); + float32x4_t d44 = vaddq_f32(src_data_33, src_data_43); + float32x4_t d45 = vaddq_f32(src_data_34, src_data_44); + float32x4_t d46 = vaddq_f32(src_data_35, src_data_45); + float32x4_t d47 = vaddq_f32(src_data_36, src_data_46); + float32x4_t d48 = vaddq_f32(src_data_37, src_data_47); + + float32x4_t d51 = vaddq_f32(src_data_50, src_data_60); + float32x4_t d52 = vaddq_f32(src_data_51, src_data_61); + float32x4_t d53 = vaddq_f32(src_data_52, src_data_62); + float32x4_t d54 = vaddq_f32(src_data_53, src_data_63); + float32x4_t d55 = vaddq_f32(src_data_54, src_data_64); + float32x4_t d56 = vaddq_f32(src_data_55, src_data_65); + float32x4_t d57 = vaddq_f32(src_data_56, src_data_66); + float32x4_t d58 = vaddq_f32(src_data_57, src_data_67); + + float32x4_t t00 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_00, src_data_10), src_data_20), src_data_30), src_data_40), + src_data_50), + src_data_60); + float32x4_t t01 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_01, src_data_11), src_data_21), src_data_31), src_data_41), + src_data_51), + src_data_61); + float32x4_t t02 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_02, src_data_12), src_data_22), src_data_32), src_data_42), + src_data_52), + src_data_62); + float32x4_t t03 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_03, src_data_13), src_data_23), src_data_33), src_data_43), + src_data_53), + src_data_63); + float32x4_t t04 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_04, src_data_14), src_data_24), src_data_34), src_data_44), + src_data_54), + src_data_64); + float32x4_t t05 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_05, src_data_15), src_data_25), src_data_35), src_data_45), + src_data_55), + src_data_65); + float32x4_t t06 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_06, src_data_16), src_data_26), src_data_36), src_data_46), + src_data_56), + src_data_66); + float32x4_t t07 = vaddq_f32( + vaddq_f32( + vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(src_data_07, src_data_17), src_data_27), src_data_37), src_data_47), + src_data_57), + src_data_67); + + float32x4_t t10 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.5), d11), vmulq_n_f32(d21, 1.5)); + float32x4_t t11 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.5), d12), vmulq_n_f32(d22, 1.5)); + float32x4_t t12 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.5), d13), vmulq_n_f32(d23, 1.5)); + float32x4_t t13 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.5), d14), vmulq_n_f32(d24, 1.5)); + float32x4_t t14 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.5), d15), vmulq_n_f32(d25, 1.5)); + float32x4_t t15 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.5), d16), vmulq_n_f32(d26, 1.5)); + float32x4_t t16 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.5), d17), vmulq_n_f32(d27, 1.5)); + float32x4_t t17 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.5), d18), vmulq_n_f32(d28, 1.5)); + + float32x4_t t20 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.25), d41), vmulq_n_f32(d51, 2.25)); + float32x4_t t21 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.25), d42), vmulq_n_f32(d52, 2.25)); + float32x4_t t22 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.25), d43), vmulq_n_f32(d53, 2.25)); + float32x4_t t23 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.25), d44), vmulq_n_f32(d54, 2.25)); + float32x4_t t24 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.25), d45), vmulq_n_f32(d55, 2.25)); + float32x4_t t25 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.25), d46), vmulq_n_f32(d56, 2.25)); + float32x4_t t26 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.25), d47), vmulq_n_f32(d57, 2.25)); + float32x4_t t27 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.25), d48), vmulq_n_f32(d58, 2.25)); + + float32x4_t t30 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.125), d11), vmulq_n_f32(d21, 3.375)); + float32x4_t t31 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.125), d12), vmulq_n_f32(d22, 3.375)); + float32x4_t t32 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.125), d13), vmulq_n_f32(d23, 3.375)); + float32x4_t t33 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.125), d14), vmulq_n_f32(d24, 3.375)); + float32x4_t t34 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.125), d15), vmulq_n_f32(d25, 3.375)); + float32x4_t t35 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.125), d16), vmulq_n_f32(d26, 3.375)); + float32x4_t t36 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.125), d17), vmulq_n_f32(d27, 3.375)); + float32x4_t t37 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.125), d18), vmulq_n_f32(d28, 3.375)); + + float32x4_t t40 = vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.0625), d41), vmulq_n_f32(d51, 5.0625)); + float32x4_t t41 = vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.0625), d42), vmulq_n_f32(d52, 5.0625)); + float32x4_t t42 = vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.0625), d43), vmulq_n_f32(d53, 5.0625)); + float32x4_t t43 = vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.0625), d44), vmulq_n_f32(d54, 5.0625)); + float32x4_t t44 = vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.0625), d45), vmulq_n_f32(d55, 5.0625)); + float32x4_t t45 = vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.0625), d46), vmulq_n_f32(d56, 5.0625)); + float32x4_t t46 = vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.0625), d47), vmulq_n_f32(d57, 5.0625)); + float32x4_t t47 = vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.0625), d48), vmulq_n_f32(d58, 5.0625)); + + float32x4_t t50 = vaddq_f32(vaddq_f32(vmulq_n_f32(d01, 0.03125), d11), vmulq_n_f32(d21, 7.59375)); + float32x4_t t51 = vaddq_f32(vaddq_f32(vmulq_n_f32(d02, 0.03125), d12), vmulq_n_f32(d22, 7.59375)); + float32x4_t t52 = vaddq_f32(vaddq_f32(vmulq_n_f32(d03, 0.03125), d13), vmulq_n_f32(d23, 7.59375)); + float32x4_t t53 = vaddq_f32(vaddq_f32(vmulq_n_f32(d04, 0.03125), d14), vmulq_n_f32(d24, 7.59375)); + float32x4_t t54 = vaddq_f32(vaddq_f32(vmulq_n_f32(d05, 0.03125), d15), vmulq_n_f32(d25, 7.59375)); + float32x4_t t55 = vaddq_f32(vaddq_f32(vmulq_n_f32(d06, 0.03125), d16), vmulq_n_f32(d26, 7.59375)); + float32x4_t t56 = vaddq_f32(vaddq_f32(vmulq_n_f32(d07, 0.03125), d17), vmulq_n_f32(d27, 7.59375)); + float32x4_t t57 = vaddq_f32(vaddq_f32(vmulq_n_f32(d08, 0.03125), d18), vmulq_n_f32(d28, 7.59375)); + + float32x4_t t60 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d31, 0.015625), d41), vmulq_n_f32(d51, 11.390625)), src_data_70); + float32x4_t t61 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d32, 0.015625), d42), vmulq_n_f32(d52, 11.390625)), src_data_71); + float32x4_t t62 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d33, 0.015625), d43), vmulq_n_f32(d53, 11.390625)), src_data_72); + float32x4_t t63 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d34, 0.015625), d44), vmulq_n_f32(d54, 11.390625)), src_data_73); + float32x4_t t64 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d35, 0.015625), d45), vmulq_n_f32(d55, 11.390625)), src_data_74); + float32x4_t t65 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d36, 0.015625), d46), vmulq_n_f32(d56, 11.390625)), src_data_75); + float32x4_t t66 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d37, 0.015625), d47), vmulq_n_f32(d57, 11.390625)), src_data_76); + float32x4_t t67 = + vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(d38, 0.015625), d48), vmulq_n_f32(d58, 11.390625)), src_data_77); + + float32x4_t s11 = vsubq_f32(t01, t02); + float32x4_t s12 = vsubq_f32(t11, t12); + float32x4_t s13 = vsubq_f32(t21, t22); + float32x4_t s14 = vsubq_f32(t31, t32); + float32x4_t s15 = vsubq_f32(t41, t42); + float32x4_t s16 = vsubq_f32(t51, t52); + float32x4_t s17 = vsubq_f32(t61, t62); + + float32x4_t s21 = vsubq_f32(t03, t04); + float32x4_t s22 = vsubq_f32(t13, t14); + float32x4_t s23 = vsubq_f32(t23, t24); + float32x4_t s24 = vsubq_f32(t33, t34); + float32x4_t s25 = vsubq_f32(t43, t44); + float32x4_t s26 = vsubq_f32(t53, t54); + float32x4_t s27 = vsubq_f32(t63, t64); + + float32x4_t s31 = vsubq_f32(t05, t06); + float32x4_t s32 = vsubq_f32(t15, t16); + float32x4_t s33 = vsubq_f32(t25, t26); + float32x4_t s34 = vsubq_f32(t35, t36); + float32x4_t s35 = vsubq_f32(t45, t46); + float32x4_t s36 = vsubq_f32(t55, t56); + float32x4_t s37 = vsubq_f32(t65, t66); + + float32x4_t s41 = vaddq_f32(t01, t02); + float32x4_t s42 = vaddq_f32(t11, t12); + float32x4_t s43 = vaddq_f32(t21, t22); + float32x4_t s44 = vaddq_f32(t31, t32); + float32x4_t s45 = vaddq_f32(t41, t42); + float32x4_t s46 = vaddq_f32(t51, t52); + float32x4_t s47 = vaddq_f32(t61, t62); + + float32x4_t s51 = vaddq_f32(t03, t04); + float32x4_t s52 = vaddq_f32(t13, t14); + float32x4_t s53 = vaddq_f32(t23, t24); + float32x4_t s54 = vaddq_f32(t33, t34); + float32x4_t s55 = vaddq_f32(t43, t44); + float32x4_t s56 = vaddq_f32(t53, t54); + float32x4_t s57 = vaddq_f32(t63, t64); + + float32x4_t s61 = vaddq_f32(t05, t06); + float32x4_t s62 = vaddq_f32(t15, t16); + float32x4_t s63 = vaddq_f32(t25, t26); + float32x4_t s64 = vaddq_f32(t35, t36); + float32x4_t s65 = vaddq_f32(t45, t46); + float32x4_t s66 = vaddq_f32(t55, t56); + float32x4_t s67 = vaddq_f32(t65, t66); + + float32x4_t m00 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), t03), t04), t05), t06); + float32x4_t m01 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.5), s21), vmulq_n_f32(s31, 1.5)); + float32x4_t m02 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.25), s51), vmulq_n_f32(s61, 2.25)); + float32x4_t m03 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.125), s21), vmulq_n_f32(s31, 3.375)); + float32x4_t m04 = vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.0625), s51), vmulq_n_f32(s61, 5.0625)); + float32x4_t m05 = vaddq_f32(vaddq_f32(vmulq_n_f32(s11, 0.03125), s21), vmulq_n_f32(s31, 7.59375)); + float32x4_t m06 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s41, 0.015625), s51), vmulq_n_f32(s61, 11.390625)), t07); + + float32x4_t m10 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), t13), t14), t15), t16); + float32x4_t m11 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.5), s22), vmulq_n_f32(s32, 1.5)); + float32x4_t m12 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.25), s52), vmulq_n_f32(s62, 2.25)); + float32x4_t m13 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.125), s22), vmulq_n_f32(s32, 3.375)); + float32x4_t m14 = vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.0625), s52), vmulq_n_f32(s62, 5.0625)); + float32x4_t m15 = vaddq_f32(vaddq_f32(vmulq_n_f32(s12, 0.03125), s22), vmulq_n_f32(s32, 7.59375)); + float32x4_t m16 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s42, 0.015625), s52), vmulq_n_f32(s62, 11.390625)), t17); + + float32x4_t m20 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t20, t21), t22), t23), t24), t25), t26); + float32x4_t m21 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.5), s23), vmulq_n_f32(s33, 1.5)); + float32x4_t m22 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.25), s53), vmulq_n_f32(s63, 2.25)); + float32x4_t m23 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.125), s23), vmulq_n_f32(s33, 3.375)); + float32x4_t m24 = vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.0625), s53), vmulq_n_f32(s63, 5.0625)); + float32x4_t m25 = vaddq_f32(vaddq_f32(vmulq_n_f32(s13, 0.03125), s23), vmulq_n_f32(s33, 7.59375)); + float32x4_t m26 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s43, 0.015625), s53), vmulq_n_f32(s63, 11.390625)), t27); + + float32x4_t m30 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t30, t31), t32), t33), t34), t35), t36); + float32x4_t m31 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.5), s24), vmulq_n_f32(s34, 1.5)); + float32x4_t m32 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.25), s54), vmulq_n_f32(s64, 2.25)); + float32x4_t m33 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.125), s24), vmulq_n_f32(s34, 3.375)); + float32x4_t m34 = vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.0625), s54), vmulq_n_f32(s64, 5.0625)); + float32x4_t m35 = vaddq_f32(vaddq_f32(vmulq_n_f32(s14, 0.03125), s24), vmulq_n_f32(s34, 7.59375)); + float32x4_t m36 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s44, 0.015625), s54), vmulq_n_f32(s64, 11.390625)), t37); + + float32x4_t m40 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t40, t41), t42), t43), t44), t45), t46); + float32x4_t m41 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.5), s25), vmulq_n_f32(s35, 1.5)); + float32x4_t m42 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.25), s55), vmulq_n_f32(s65, 2.25)); + float32x4_t m43 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.125), s25), vmulq_n_f32(s35, 3.375)); + float32x4_t m44 = vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.0625), s55), vmulq_n_f32(s65, 5.0625)); + float32x4_t m45 = vaddq_f32(vaddq_f32(vmulq_n_f32(s15, 0.03125), s25), vmulq_n_f32(s35, 7.59375)); + float32x4_t m46 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s45, 0.015625), s55), vmulq_n_f32(s65, 11.390625)), t47); + + float32x4_t m50 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t50, t51), t52), t53), t54), t55), t56); + float32x4_t m51 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.5), s26), vmulq_n_f32(s36, 1.5)); + float32x4_t m52 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.25), s56), vmulq_n_f32(s66, 2.25)); + float32x4_t m53 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.125), s26), vmulq_n_f32(s36, 3.375)); + float32x4_t m54 = vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.0625), s56), vmulq_n_f32(s66, 5.0625)); + float32x4_t m55 = vaddq_f32(vaddq_f32(vmulq_n_f32(s16, 0.03125), s26), vmulq_n_f32(s36, 7.59375)); + float32x4_t m56 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s46, 0.015625), s56), vmulq_n_f32(s66, 11.390625)), t57); + + float32x4_t m60 = vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(vaddq_f32(t60, t61), t62), t63), t64), t65), t66); + float32x4_t m61 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.5), s27), vmulq_n_f32(s37, 1.5)); + float32x4_t m62 = vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.25), s57), vmulq_n_f32(s67, 2.25)); + float32x4_t m63 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.125), s27), vmulq_n_f32(s37, 3.375)); + float32x4_t m64 = vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.0625), s57), vmulq_n_f32(s67, 5.0625)); + float32x4_t m65 = vaddq_f32(vaddq_f32(vmulq_n_f32(s17, 0.03125), s27), vmulq_n_f32(s37, 7.59375)); + float32x4_t m66 = vaddq_f32(vaddq_f32(vaddq_f32(vmulq_n_f32(s47, 0.015625), s57), vmulq_n_f32(s67, 11.390625)), t67); + + float32x4_t bias_ptr = vld1q_f32(bias_data); + vst1q_f32(dst_data, vaddq_f32(m00, bias_ptr)); + vst1q_f32(dst_data + C4NUM, vaddq_f32(m01, bias_ptr)); + vst1q_f32(dst_data + 2 * C4NUM, vaddq_f32(m02, bias_ptr)); + vst1q_f32(dst_data + 3 * C4NUM, vaddq_f32(m03, bias_ptr)); + vst1q_f32(dst_data + 4 * C4NUM, vaddq_f32(m04, bias_ptr)); + vst1q_f32(dst_data + 5 * C4NUM, vaddq_f32(m05, bias_ptr)); + vst1q_f32(dst_data + 6 * C4NUM, vaddq_f32(m06, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM, vaddq_f32(m10, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, vaddq_f32(m11, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m12, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m13, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m14, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m15, bias_ptr)); + vst1q_f32(dst_data + dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m16, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM, vaddq_f32(m20, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, vaddq_f32(m21, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m22, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m23, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m24, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m25, bias_ptr)); + vst1q_f32(dst_data + 2 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m26, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM, vaddq_f32(m30, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, vaddq_f32(m31, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m32, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m33, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m34, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m35, bias_ptr)); + vst1q_f32(dst_data + 3 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m36, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM, vaddq_f32(m40, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, vaddq_f32(m41, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m42, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m43, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m44, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m45, bias_ptr)); + vst1q_f32(dst_data + 4 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m46, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM, vaddq_f32(m50, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + C4NUM, vaddq_f32(m51, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m52, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m53, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m54, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m55, bias_ptr)); + vst1q_f32(dst_data + 5 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m56, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM, vaddq_f32(m60, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + C4NUM, vaddq_f32(m61, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 2 * C4NUM, vaddq_f32(m62, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 3 * C4NUM, vaddq_f32(m63, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 4 * C4NUM, vaddq_f32(m64, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 5 * C4NUM, vaddq_f32(m65, bias_ptr)); + vst1q_f32(dst_data + 6 * dst_step * C4NUM + 6 * C4NUM, vaddq_f32(m66, bias_ptr)); +#else + for (int i = 0; i < C4NUM; i++) { + float src_data_00 = src_data[i]; + float src_data_01 = src_data[i + src_step]; + float src_data_02 = src_data[i + 2 * src_step]; + float src_data_03 = src_data[i + 3 * src_step]; + float src_data_04 = src_data[i + 4 * src_step]; + float src_data_05 = src_data[i + 5 * src_step]; + float src_data_06 = src_data[i + 6 * src_step]; + float src_data_07 = src_data[i + 7 * src_step]; + float src_data_10 = src_data[i + 8 * src_step]; + float src_data_11 = src_data[i + 9 * src_step]; + float src_data_12 = src_data[i + 10 * src_step]; + float src_data_13 = src_data[i + 11 * src_step]; + float src_data_14 = src_data[i + 12 * src_step]; + float src_data_15 = src_data[i + 13 * src_step]; + float src_data_16 = src_data[i + 14 * src_step]; + float src_data_17 = src_data[i + 15 * src_step]; + float src_data_20 = src_data[i + 16 * src_step]; + float src_data_21 = src_data[i + 17 * src_step]; + float src_data_22 = src_data[i + 18 * src_step]; + float src_data_23 = src_data[i + 19 * src_step]; + float src_data_24 = src_data[i + 20 * src_step]; + float src_data_25 = src_data[i + 21 * src_step]; + float src_data_26 = src_data[i + 22 * src_step]; + float src_data_27 = src_data[i + 23 * src_step]; + float src_data_30 = src_data[i + 24 * src_step]; + float src_data_31 = src_data[i + 25 * src_step]; + float src_data_32 = src_data[i + 26 * src_step]; + float src_data_33 = src_data[i + 27 * src_step]; + float src_data_34 = src_data[i + 28 * src_step]; + float src_data_35 = src_data[i + 29 * src_step]; + float src_data_36 = src_data[i + 30 * src_step]; + float src_data_37 = src_data[i + 31 * src_step]; + float src_data_40 = src_data[i + 32 * src_step]; + float src_data_41 = src_data[i + 33 * src_step]; + float src_data_42 = src_data[i + 34 * src_step]; + float src_data_43 = src_data[i + 35 * src_step]; + float src_data_44 = src_data[i + 36 * src_step]; + float src_data_45 = src_data[i + 37 * src_step]; + float src_data_46 = src_data[i + 38 * src_step]; + float src_data_47 = src_data[i + 39 * src_step]; + float src_data_50 = src_data[i + 40 * src_step]; + float src_data_51 = src_data[i + 41 * src_step]; + float src_data_52 = src_data[i + 42 * src_step]; + float src_data_53 = src_data[i + 43 * src_step]; + float src_data_54 = src_data[i + 44 * src_step]; + float src_data_55 = src_data[i + 45 * src_step]; + float src_data_56 = src_data[i + 46 * src_step]; + float src_data_57 = src_data[i + 47 * src_step]; + float src_data_60 = src_data[i + 48 * src_step]; + float src_data_61 = src_data[i + 49 * src_step]; + float src_data_62 = src_data[i + 50 * src_step]; + float src_data_63 = src_data[i + 51 * src_step]; + float src_data_64 = src_data[i + 52 * src_step]; + float src_data_65 = src_data[i + 53 * src_step]; + float src_data_66 = src_data[i + 54 * src_step]; + float src_data_67 = src_data[i + 55 * src_step]; + float src_data_70 = src_data[i + 56 * src_step]; + float src_data_71 = src_data[i + 57 * src_step]; + float src_data_72 = src_data[i + 58 * src_step]; + float src_data_73 = src_data[i + 59 * src_step]; + float src_data_74 = src_data[i + 60 * src_step]; + float src_data_75 = src_data[i + 61 * src_step]; + float src_data_76 = src_data[i + 62 * src_step]; + float src_data_77 = src_data[i + 63 * src_step]; + + float d01 = src_data_10 - src_data_20; + float d02 = src_data_11 - src_data_21; + float d03 = src_data_12 - src_data_22; + float d04 = src_data_13 - src_data_23; + float d05 = src_data_14 - src_data_24; + float d06 = src_data_15 - src_data_25; + float d07 = src_data_16 - src_data_26; + float d08 = src_data_17 - src_data_27; + + float d11 = src_data_30 - src_data_40; + float d12 = src_data_31 - src_data_41; + float d13 = src_data_32 - src_data_42; + float d14 = src_data_33 - src_data_43; + float d15 = src_data_34 - src_data_44; + float d16 = src_data_35 - src_data_45; + float d17 = src_data_36 - src_data_46; + float d18 = src_data_37 - src_data_47; + + float d21 = src_data_50 - src_data_60; + float d22 = src_data_51 - src_data_61; + float d23 = src_data_52 - src_data_62; + float d24 = src_data_53 - src_data_63; + float d25 = src_data_54 - src_data_64; + float d26 = src_data_55 - src_data_65; + float d27 = src_data_56 - src_data_66; + float d28 = src_data_57 - src_data_67; + + float d31 = src_data_10 + src_data_20; + float d32 = src_data_11 + src_data_21; + float d33 = src_data_12 + src_data_22; + float d34 = src_data_13 + src_data_23; + float d35 = src_data_14 + src_data_24; + float d36 = src_data_15 + src_data_25; + float d37 = src_data_16 + src_data_26; + float d38 = src_data_17 + src_data_27; + + float d41 = src_data_30 + src_data_40; + float d42 = src_data_31 + src_data_41; + float d43 = src_data_32 + src_data_42; + float d44 = src_data_33 + src_data_43; + float d45 = src_data_34 + src_data_44; + float d46 = src_data_35 + src_data_45; + float d47 = src_data_36 + src_data_46; + float d48 = src_data_37 + src_data_47; + + float d51 = src_data_50 + src_data_60; + float d52 = src_data_51 + src_data_61; + float d53 = src_data_52 + src_data_62; + float d54 = src_data_53 + src_data_63; + float d55 = src_data_54 + src_data_64; + float d56 = src_data_55 + src_data_65; + float d57 = src_data_56 + src_data_66; + float d58 = src_data_57 + src_data_67; + + float t00 = src_data_00 + src_data_10 + src_data_20 + src_data_30 + src_data_40 + src_data_50 + src_data_60; + float t01 = src_data_01 + src_data_11 + src_data_21 + src_data_31 + src_data_41 + src_data_51 + src_data_61; + float t02 = src_data_02 + src_data_12 + src_data_22 + src_data_32 + src_data_42 + src_data_52 + src_data_62; + float t03 = src_data_03 + src_data_13 + src_data_23 + src_data_33 + src_data_43 + src_data_53 + src_data_63; + float t04 = src_data_04 + src_data_14 + src_data_24 + src_data_34 + src_data_44 + src_data_54 + src_data_64; + float t05 = src_data_05 + src_data_15 + src_data_25 + src_data_35 + src_data_45 + src_data_55 + src_data_65; + float t06 = src_data_06 + src_data_16 + src_data_26 + src_data_36 + src_data_46 + src_data_56 + src_data_66; + float t07 = src_data_07 + src_data_17 + src_data_27 + src_data_37 + src_data_47 + src_data_57 + src_data_67; + + float t10 = 0.5f * d01 + d11 + 1.5f * d21; + float t11 = 0.5f * d02 + d12 + 1.5f * d22; + float t12 = 0.5f * d03 + d13 + 1.5f * d23; + float t13 = 0.5f * d04 + d14 + 1.5f * d24; + float t14 = 0.5f * d05 + d15 + 1.5f * d25; + float t15 = 0.5f * d06 + d16 + 1.5f * d26; + float t16 = 0.5f * d07 + d17 + 1.5f * d27; + float t17 = 0.5f * d08 + d18 + 1.5f * d28; + + float t20 = 0.25f * d31 + d41 + 2.25f * d51; + float t21 = 0.25f * d32 + d42 + 2.25f * d52; + float t22 = 0.25f * d33 + d43 + 2.25f * d53; + float t23 = 0.25f * d34 + d44 + 2.25f * d54; + float t24 = 0.25f * d35 + d45 + 2.25f * d55; + float t25 = 0.25f * d36 + d46 + 2.25f * d56; + float t26 = 0.25f * d37 + d47 + 2.25f * d57; + float t27 = 0.25f * d38 + d48 + 2.25f * d58; + + float t30 = 0.125f * d01 + d11 + 3.375f * d21; + float t31 = 0.125f * d02 + d12 + 3.375f * d22; + float t32 = 0.125f * d03 + d13 + 3.375f * d23; + float t33 = 0.125f * d04 + d14 + 3.375f * d24; + float t34 = 0.125f * d05 + d15 + 3.375f * d25; + float t35 = 0.125f * d06 + d16 + 3.375f * d26; + float t36 = 0.125f * d07 + d17 + 3.375f * d27; + float t37 = 0.125f * d08 + d18 + 3.375f * d28; + + float t40 = 0.0625f * d31 + d41 + 5.0625f * d51; + float t41 = 0.0625f * d32 + d42 + 5.0625f * d52; + float t42 = 0.0625f * d33 + d43 + 5.0625f * d53; + float t43 = 0.0625f * d34 + d44 + 5.0625f * d54; + float t44 = 0.0625f * d35 + d45 + 5.0625f * d55; + float t45 = 0.0625f * d36 + d46 + 5.0625f * d56; + float t46 = 0.0625f * d37 + d47 + 5.0625f * d57; + float t47 = 0.0625f * d38 + d48 + 5.0625f * d58; + + float t50 = 0.03125f * d01 + d11 + 7.59375f * d21; + float t51 = 0.03125f * d02 + d12 + 7.59375f * d22; + float t52 = 0.03125f * d03 + d13 + 7.59375f * d23; + float t53 = 0.03125f * d04 + d14 + 7.59375f * d24; + float t54 = 0.03125f * d05 + d15 + 7.59375f * d25; + float t55 = 0.03125f * d06 + d16 + 7.59375f * d26; + float t56 = 0.03125f * d07 + d17 + 7.59375f * d27; + float t57 = 0.03125f * d08 + d18 + 7.59375f * d28; + + float t60 = 0.015625f * d31 + d41 + 11.390625f * d51 + src_data_70; + float t61 = 0.015625f * d32 + d42 + 11.390625f * d52 + src_data_71; + float t62 = 0.015625f * d33 + d43 + 11.390625f * d53 + src_data_72; + float t63 = 0.015625f * d34 + d44 + 11.390625f * d54 + src_data_73; + float t64 = 0.015625f * d35 + d45 + 11.390625f * d55 + src_data_74; + float t65 = 0.015625f * d36 + d46 + 11.390625f * d56 + src_data_75; + float t66 = 0.015625f * d37 + d47 + 11.390625f * d57 + src_data_76; + float t67 = 0.015625f * d38 + d48 + 11.390625f * d58 + src_data_77; + + float s11 = t01 - t02; + float s12 = t11 - t12; + float s13 = t21 - t22; + float s14 = t31 - t32; + float s15 = t41 - t42; + float s16 = t51 - t52; + float s17 = t61 - t62; + + float s21 = t03 - t04; + float s22 = t13 - t14; + float s23 = t23 - t24; + float s24 = t33 - t34; + float s25 = t43 - t44; + float s26 = t53 - t54; + float s27 = t63 - t64; + + float s31 = t05 - t06; + float s32 = t15 - t16; + float s33 = t25 - t26; + float s34 = t35 - t36; + float s35 = t45 - t46; + float s36 = t55 - t56; + float s37 = t56 - t66; + + float s41 = t01 + t02; + float s42 = t11 + t12; + float s43 = t21 + t22; + float s44 = t31 + t32; + float s45 = t41 + t42; + float s46 = t51 + t52; + float s47 = t61 + t62; + + float s51 = t03 + t04; + float s52 = t13 + t14; + float s53 = t23 + t24; + float s54 = t33 + t34; + float s55 = t43 + t44; + float s56 = t53 + t54; + float s57 = t63 + t64; + + float s61 = t05 + t06; + float s62 = t15 + t16; + float s63 = t25 + t26; + float s64 = t35 + t36; + float s65 = t45 + t46; + float s66 = t55 + t56; + float s67 = t65 + t66; + + float m00 = t00 + t01 + t02 + t03 + t04 + t05 + t06; + float m01 = 0.5f * s11 + s21 + 1.5f * s31; + float m02 = 0.25f * s41 + s51 + 2.25f * s61; + float m03 = 0.125f * s11 + s21 + 3.375f * s31; + float m04 = 0.0625f * s41 + s51 + 5.0625f * s61; + float m05 = 0.03125f * s11 + s21 + 7.59375f * s31; + float m06 = 0.015625f * s41 + s51 + 11.390625f * s61 + t07; + + float m10 = t10 + t11 + t12 + t13 + t14 + t15 + t16; + float m11 = 0.5f * s12 + s22 + 1.5f * s32; + float m12 = 0.25f * s42 + s52 + 2.25f * s62; + float m13 = 0.125f * s12 + s22 + 3.375f * s32; + float m14 = 0.0625f * s42 + s52 + 5.0625f * s62; + float m15 = 0.03125f * s12 + s22 + 7.59375f * s32; + float m16 = 0.015625f * s42 + s52 + 11.390625f * s62 + t17; + + float m20 = t20 + t21 + t22 + t23 + t24 + t25 + t26; + float m21 = 0.5f * s13 + s23 + 1.5f * s33; + float m22 = 0.25f * s43 + s53 + 2.25f * s63; + float m23 = 0.125f * s13 + s23 + 3.375f * s33; + float m24 = 0.0625f * s43 + s53 + 5.0625f * s63; + float m25 = 0.03125f * s13 + s23 + 7.59375f * s33; + float m26 = 0.015625f * s43 + s53 + 11.390625f * s63 + t27; + + float m30 = t30 + t31 + t32 + t33 + t34 + t35 + t36; + float m31 = 0.5f * s14 + s24 + 1.5f * s34; + float m32 = 0.25f * s44 + s54 + 2.25f * s64; + float m33 = 0.125f * s14 + s24 + 3.375f * s34; + float m34 = 0.0625f * s44 + s54 + 5.0625f * s64; + float m35 = 0.03125f * s14 + s24 + 7.59375f * s34; + float m36 = 0.015625f * s44 + s54 + 11.390625f * s64 + t37; + + float m40 = t40 + t41 + t42 + t43 + t44 + t45 + t46; + float m41 = 0.5f * s15 + s25 + 1.5f * s35; + float m42 = 0.25f * s45 + s55 + 2.25f * s65; + float m43 = 0.125f * s15 + s25 + 3.375f * s35; + float m44 = 0.0625f * s45 + s55 + 5.0625f * s65; + float m45 = 0.03125f * s15 + s25 + 7.59375f * s35; + float m46 = 0.015625f * s45 + s55 + 11.390625f * s65 + t47; + + float m50 = t50 + t51 + t52 + t53 + t54 + t55 + t56; + float m51 = 0.5f * s16 + s26 + 1.5f * s36; + float m52 = 0.25f * s46 + s56 + 2.25f * s66; + float m53 = 0.125f * s16 + s26 + 3.375f * s36; + float m54 = 0.0625f * s46 + s56 + 5.0625f * s66; + float m55 = 0.03125f * s16 + s26 + 7.59375f * s36; + float m56 = 0.015625f * s46 + s56 + 11.390625f * s66 + t57; + + float m60 = t60 + t61 + t62 + t63 + t64 + t65 + t66; + float m61 = 0.5f * s17 + s27 + 1.5f * s37; + float m62 = 0.25f * s47 + s57 + 2.25f * s67; + float m63 = 0.125f * s17 + s27 + 3.375f * s37; + float m64 = 0.0625f * s47 + s57 + 5.0625f * s67; + float m65 = 0.03125f * s17 + s27 + 7.59375f * s37; + float m66 = 0.015625f * s47 + s57 + 11.390625f * s67 + t67; + + (dst_data + i)[0] = m00 + bias_data[i]; + (dst_data + i + C4NUM)[0] = m01 + bias_data[i]; + (dst_data + i + 2 * C4NUM)[0] = m02 + bias_data[i]; + (dst_data + i + 3 * C4NUM)[0] = m03 + bias_data[i]; + (dst_data + i + 4 * C4NUM)[0] = m04 + bias_data[i]; + (dst_data + i + 5 * C4NUM)[0] = m05 + bias_data[i]; + (dst_data + i + 6 * C4NUM)[0] = m06 + bias_data[i]; + + (dst_data + i + dst_step * C4NUM)[0] = m10 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + C4NUM)[0] = m11 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 2 * C4NUM)[0] = m12 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 3 * C4NUM)[0] = m13 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 4 * C4NUM)[0] = m14 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 5 * C4NUM)[0] = m15 + bias_data[i]; + (dst_data + i + dst_step * C4NUM + 6 * C4NUM)[0] = m16 + bias_data[i]; + + (dst_data + i + 2 * dst_step * C4NUM)[0] = m20 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + C4NUM)[0] = m21 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 2 * C4NUM)[0] = m22 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 3 * C4NUM)[0] = m23 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 4 * C4NUM)[0] = m24 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 5 * C4NUM)[0] = m25 + bias_data[i]; + (dst_data + i + 2 * dst_step * C4NUM + 6 * C4NUM)[0] = m26 + bias_data[i]; + + (dst_data + i + 3 * dst_step * C4NUM)[0] = m30 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + C4NUM)[0] = m31 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 2 * C4NUM)[0] = m32 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 3 * C4NUM)[0] = m33 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 4 * C4NUM)[0] = m34 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 5 * C4NUM)[0] = m35 + bias_data[i]; + (dst_data + i + 3 * dst_step * C4NUM + 6 * C4NUM)[0] = m36 + bias_data[i]; + + (dst_data + i + 4 * dst_step * C4NUM)[0] = m40 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + C4NUM)[0] = m41 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 2 * C4NUM)[0] = m42 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 3 * C4NUM)[0] = m43 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 4 * C4NUM)[0] = m44 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 5 * C4NUM)[0] = m45 + bias_data[i]; + (dst_data + i + 4 * dst_step * C4NUM + 6 * C4NUM)[0] = m46 + bias_data[i]; + + (dst_data + i + 5 * dst_step * C4NUM)[0] = m50 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + C4NUM)[0] = m51 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 2 * C4NUM)[0] = m52 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 3 * C4NUM)[0] = m53 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 4 * C4NUM)[0] = m54 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 5 * C4NUM)[0] = m55 + bias_data[i]; + (dst_data + i + 5 * dst_step * C4NUM + 6 * C4NUM)[0] = m56 + bias_data[i]; + + (dst_data + i + 6 * dst_step * C4NUM)[0] = m60 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + C4NUM)[0] = m61 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 2 * C4NUM)[0] = m62 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 3 * C4NUM)[0] = m63 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 4 * C4NUM)[0] = m64 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 5 * C4NUM)[0] = m65 + bias_data[i]; + (dst_data + i + 6 * dst_step * C4NUM + 6 * C4NUM)[0] = m66 + bias_data[i]; + } +#endif +} + +// Reference to the paper "Fast Algorithms for Convolutional Neural Networks" +// Utilize cost model to compute performance gain. +// If the gain is greater than got from Im2col, winograd algorithm will be chosen. +int SelectOutputUnit(ConvParameter *conv_param) { + auto input_batch = conv_param->input_batch_; + auto kernel_h = conv_param->kernel_h_; + auto kernel_w = conv_param->kernel_w_; + auto in_channel = conv_param->input_channel_; + auto out_h = conv_param->output_h_; + auto out_w = conv_param->output_w_; + auto out_channel = conv_param->output_channel_; + int out_plane = out_h * out_w; + + int max_unit = ::sqrt((float)(out_plane)); + max_unit = max_unit > MIN_UNIT ? max_unit : MIN_UNIT; + max_unit = max_unit < MAX_UNIT ? max_unit : MAX_UNIT; + int output_unit = 1; + float ratio = 0.0f; + // cost of conventional convolution multiplications + float ori_cost = out_plane * out_channel * in_channel * kernel_h * kernel_w; + + for (int u = MIN_UNIT; u < max_unit; u++) { + auto input_unit = u + kernel_h - 1; + if (input_unit != 4 && input_unit != 8) { + continue; + } + // don't count filter transform cost, because it can be processed once offline. + float input_trans_unit_cost = 2 * input_unit * input_unit * input_unit * in_channel; + float gemm_unit_cost = input_unit * input_unit * in_channel * out_channel; + float output_trans_unit_cost = input_unit * u * (u + input_unit) * out_channel; + // equation (23) in papar + float winograd_cost = (input_trans_unit_cost + gemm_unit_cost + output_trans_unit_cost) * + (UP_DIV(out_w, u) * (UP_DIV(out_h, u))) * input_batch; + float reduce_rate = ori_cost / winograd_cost; + if (reduce_rate > ratio && reduce_rate > 1) { + ratio = reduce_rate; + output_unit = u; + } + } + // If output_unit is 1, then it is conventional convolution + return output_unit; +} + +InputTransformUnitFunc GetInputTransFunc(int input_unit) { + if (input_unit == 4) { + return InputTransform4x4Unit; + } else if (input_unit == 8) { + return InputTransform8x8Unit; + } else { + printf("Only support 4 or 8 for input unit."); + return nullptr; + } +} + +OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit) { + if (input_unit == 4 && output_unit == 2) { + return OutputTransform4x2Unit; + } else if (input_unit == 4 && output_unit == 3) { + return OutputTransform4x3Unit; + } else if (input_unit == 8) { + return outputTransformUnit[output_unit]; + } else { + printf("."); + return nullptr; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.h b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.h new file mode 100644 index 00000000000..55425a4a3e7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_UTILS_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include "src/runtime/kernel/arm/opclib/matrix_table.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/kernel/arm/opclib/op_base.h" + +using InputTransformUnitFunc = void (*)(const float *src_data, float *dst_data, int src_step, int dst_step); +using OutputTransformUnitFunc = void (*)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step); + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step); + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step); + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step); + +int SelectOutputUnit(ConvParameter *conv_param); + +InputTransformUnitFunc GetInputTransFunc(int input_unit); + +OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_WINOGRAD_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.cc new file mode 100644 index 00000000000..9d206e2cc50 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.cc @@ -0,0 +1,21 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/opclib/zeroslike.h" +#include +#include + +void ApproximateZerosLike(float *input, float *output, int number) { memset(output, 0.0, number * sizeof(float)); } + diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.h b/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.h new file mode 100644 index 00000000000..b2e2fa35a44 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/zeroslike.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ZEROSLIKE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ZEROSLIKE_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" + +void ApproximateZerosLike(float *input, float *output, int number); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_ZEROSLIKE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt new file mode 100644 index 00000000000..6fae2e3a768 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt @@ -0,0 +1,12 @@ +set(OPENCL_KERNEL_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_opencl_kernel.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/arithmetic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/convolution.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/depthwise_conv2d.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/pooling2d.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/matmul.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/softmax.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/concat.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel/conv2d_transpose.cc + ) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl new file mode 100644 index 00000000000..fa70a06c731 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl @@ -0,0 +1,52 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define FLT half +#define FLT4 half4 +#define FLT16 half16 +__kernel void conv2d_transpose2x2(__global FLT4 *inputx, __global FLT16 *weight, __global FLT4 *bias, + __global FLT4 *output, int2 kernel_size, int2 stride, int2 padding, int4 src_size, + int4 dst_size) { + int h = get_global_id(0); + int w = get_global_id(1); + int co = get_global_id(2); + if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return; + FLT4 r0 = (FLT4)(0.f); + FLT4 r1 = (FLT4)(0.f); + FLT4 r2 = (FLT4)(0.f); + FLT4 r3 = (FLT4)(0.f); + int base_x = (h * src_size.y + w) * src_size.z; + int base_w = co * src_size.z; + for (int ci = 0; ci < src_size.z; ++ci) { + FLT4 x = inputx[base_x + ci]; + FLT16 w0 = weight[(base_w + ci) * 4]; + FLT16 w1 = weight[(base_w + ci) * 4 + 1]; + FLT16 w2 = weight[(base_w + ci) * 4 + 2]; + FLT16 w3 = weight[(base_w + ci) * 4 + 3]; + r0 += x.x * w0.s0123; + r0 += x.y * w0.s4567; + r0 += x.z * w0.s89ab; + r0 += x.w * w0.scdef; + + r1 += x.x * w1.s0123; + r1 += x.y * w1.s4567; + r1 += x.z * w1.s89ab; + r1 += x.w * w1.scdef; + + r2 += x.x * w2.s0123; + r2 += x.y * w2.s4567; + r2 += x.z * w2.s89ab; + r2 += x.w * w2.scdef; + + r3 += x.x * w3.s0123; + r3 += x.y * w3.s4567; + r3 += x.z * w3.s89ab; + r3 += x.w * w3.scdef; + } + r0 += bias[co]; + r1 += bias[co]; + r2 += bias[co]; + r3 += bias[co]; + output[((2 * h + 0) * dst_size.y + 2 * w + 0) * dst_size.z + co] = r0; + output[((2 * h + 0) * dst_size.y + 2 * w + 1) * dst_size.z + co] = r1; + output[((2 * h + 1) * dst_size.y + 2 * w + 0) * dst_size.z + co] = r2; + output[((2 * h + 1) * dst_size.y + 2 * w + 1) * dst_size.z + co] = r3; +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl new file mode 100644 index 00000000000..7e327ba0a3a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define ACCUM_FLT4 half4 +#define FLT half +#define FLT2 half2 +#define FLT3 half3 +#define FLT4 half4 +#define TO_FLT4 convert_half4 +#define TO_ACCUM_TYPE convert_half4 +#define TO_ACCUM_FLT convert_half +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_edge = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void DepthwiseConv2d_NC4HW4( +__global FLT4* src_data, + __global FLT4* filters, +__global FLT4* biases, + float relu_clip1, +__global FLT4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size +) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; + r += TO_ACCUM_TYPE(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0; +} + +__kernel void DepthwiseConv2d_NHWC4( +__global FLT4* src_data, + __global FLT4* filters, +__global FLT4* biases, + float relu_clip1, +__global FLT4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size +) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_ACCUM_TYPE(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl new file mode 100644 index 00000000000..10c66fc3626 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/matmul.cl @@ -0,0 +1,32 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define FLT4 half4 +#define FLT16 half16 +__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, + __global FLT4 *buffer, __global FLT4 *bias, int2 offset_ci, + int2 offset_co, int has_bias) { + int2 gid = (int2)(get_global_id(0), get_global_id(1)); + int2 lid = (int2)(get_local_id(0), get_local_id(1)); + FLT4 s = (FLT4)(0.0f); + bool inside = gid.x < offset_co.y; + for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { + FLT4 v = x[i]; + FLT16 w = weight[gid.x + i * offset_co.y]; + s.x += dot(v, w.s0123); + s.y += dot(v, w.s4567); + s.z += dot(v, w.s89ab); + s.w += dot(v, w.scdef); + } + __local FLT4 temp[64][4]; + temp[lid.x][lid.y] = s; + barrier(CLK_LOCAL_MEM_FENCE); + if (lid.y == 0 && inside) { + s += temp[lid.x][1]; + s += temp[lid.x][2]; + s += temp[lid.x][3]; + if (has_bias != 0) { + s += bias[gid.x]; + } + buffer[gid.x] = s; + // memory pollution? or protected by opencl + } +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic.cl new file mode 100644 index 00000000000..255f8b3623f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/arithmetic.cl @@ -0,0 +1,49 @@ +__kernel void ArithmeticAdd(__global float *input_a, + __global float *input_b, + __global float *output, + const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + output[id] = input_a[id] + input_b[id]; + } +} +__kernel void ArithmeticSub(__global float *input_a, + __global float *input_b, + __global float *output, + const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + output[id] = input_a[id] - input_b[id]; + } +} +__kernel void ArithmeticMul(__global float *input_a, + __global float *input_b, + __global float *output, + const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + output[id] = input_a[id] * input_b[id]; + } +} +__kernel void ArithmeticDiv(__global float *input_a, + __global float *input_b, + __global float *output, + const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + output[id] = input_a[id] * input_b[id]; + } +} + +__kernel void ArithmeticBiasAdd(__global float4 *input, + __global float4 *output, + const float weight, + const float bias, + const unsigned int n) { + int id = get_global_id(0); + float4 bias_vec = (float4)(bias, 0.0f, .0f, .0f); + float4 weight_vec = (float4)(weight, 0.0f, .0f, .0f); + if (id < n) { + output[id] = weight_vec * input[id] + bias_vec; + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl new file mode 100644 index 00000000000..1d2ca1216c7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl @@ -0,0 +1,66 @@ +__kernel void AvgPooling2d(__global float4 *input, __global float4 *output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 r = (float4)(0.0f); + float window_size = 0.0f; + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + bool outside_x = x_c < 0 || x_c >= input_shape.x; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + bool outside = outside_x || y_c < 0 || y_c >= input_shape.y; + r += !outside ? input[(input_shape.y * x_c + y_c) * output_shape.w + Z] : (float4)(0.0f); + window_size += !outside ? 1.0f : 0.0f; + } + } + float4 result = convert_float4(r / window_size); + output[(output_shape.y * X + Y) * output_shape.w + Z] = result; +} + +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + +__kernel void AvgPooling2dImage2d(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, + const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 r = (float4)(0.0f); + float window_size = 0.0f; + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + bool outside_x = x_c < 0 || x_c >= input_shape.x; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + bool outside = outside_x || y_c < 0 || y_c >= input_shape.y; + + r += read_imagef(input, smp_zero, (int2)(x_c, y_c * input_shape.w + Z)); + window_size += !outside ? 1.0f : 0.0f; + } + } + float4 result = convert_float4(r / window_size); + write_imagef(output, (int2)(X, Y * output_shape.w + Z), result); +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl new file mode 100644 index 00000000000..9525863c6f5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl @@ -0,0 +1,60 @@ +//#pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0, + const int4 input_shape1, const int4 output_shape, const int axis) { + int postion = 0, index_input_shape0 = 0, index_input_shape1 = 0; + switch (axis) { + case 1: + for (int i = 0; i < output_shape.x; i++) { + for (int j = 0; j < output_shape.y; j++) { + for (int k = 0; k < output_shape.z; k++) { + for (int w = 0; w < output_shape.w; w++) { + postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w + + k * output_shape.w + w; + if (j < input_shape0.y) { + output[postion] = input0[index_input_shape0++]; + } else { + output[postion] = input1[index_input_shape1++]; + } + } + } + } + } + break; + case 2: + for (int i = 0; i < output_shape.x; i++) { + for (int j = 0; j < output_shape.y; j++) { + for (int k = 0; k < output_shape.z; k++) { + for (int w = 0; w < output_shape.w; w++) { + postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w + + k * output_shape.w + w; + if (k < input_shape0.z) { + output[postion] = input0[index_input_shape0++]; + } else { + output[postion] = input1[index_input_shape1++]; + } + } + } + } + } + break; + case 3: + for (int i = 0; i < output_shape.x; i++) { + for (int j = 0; j < output_shape.y; j++) { + for (int k = 0; k < output_shape.z; k++) { + for (int w = 0; w < output_shape.w; w++) { + postion = i * output_shape.y * output_shape.z * output_shape.w + j * output_shape.z * output_shape.w + + k * output_shape.w + w; + if (w < input_shape0.w) { + output[postion] = input0[index_input_shape0++]; + } else { + output[postion] = input1[index_input_shape1++]; + } + } + } + } + } + break; + default: + break; + } +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl new file mode 100644 index 00000000000..d757d926bd9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl @@ -0,0 +1,51 @@ +#define FLT float +#define FLT4 float4 +#define FLT16 float16 +__kernel void conv2d_transpose2x2(__global FLT4 *inputx, __global FLT16 *weight, __global FLT4 *bias, + __global FLT4 *output, int2 kernel_size, int2 stride, int2 padding, int4 src_size, + int4 dst_size) { + int h = get_global_id(0); + int w = get_global_id(1); + int co = get_global_id(2); + if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return; + FLT4 r0 = (FLT4)(0.f); + FLT4 r1 = (FLT4)(0.f); + FLT4 r2 = (FLT4)(0.f); + FLT4 r3 = (FLT4)(0.f); + int base_x = (h * src_size.y + w) * src_size.z; + int base_w = co * src_size.z; + for (int ci = 0; ci < src_size.z; ++ci) { + FLT4 x = inputx[base_x + ci]; + FLT16 w0 = weight[(base_w + ci) * 4]; + FLT16 w1 = weight[(base_w + ci) * 4 + 1]; + FLT16 w2 = weight[(base_w + ci) * 4 + 2]; + FLT16 w3 = weight[(base_w + ci) * 4 + 3]; + r0 += x.x * w0.s0123; + r0 += x.y * w0.s4567; + r0 += x.z * w0.s89ab; + r0 += x.w * w0.scdef; + + r1 += x.x * w1.s0123; + r1 += x.y * w1.s4567; + r1 += x.z * w1.s89ab; + r1 += x.w * w1.scdef; + + r2 += x.x * w2.s0123; + r2 += x.y * w2.s4567; + r2 += x.z * w2.s89ab; + r2 += x.w * w2.scdef; + + r3 += x.x * w3.s0123; + r3 += x.y * w3.s4567; + r3 += x.z * w3.s89ab; + r3 += x.w * w3.scdef; + } + r0 += bias[co]; + r1 += bias[co]; + r2 += bias[co]; + r3 += bias[co]; + output[((2 * h + 0) * dst_size.y + 2 * w + 0) * dst_size.z + co] = r0; + output[((2 * h + 0) * dst_size.y + 2 * w + 1) * dst_size.z + co] = r1; + output[((2 * h + 1) * dst_size.y + 2 * w + 0) * dst_size.z + co] = r2; + output[((2 * h + 1) * dst_size.y + 2 * w + 1) * dst_size.z + co] = r3; +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl new file mode 100644 index 00000000000..af7f858a87c --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/convolution.cl @@ -0,0 +1,87 @@ +#define CI_TILE 4 +#define CO_TILE 4 + +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) + +//#pragma OPENCL EXTENSION cl_arm_printf : enable +__kernel void convolution_NHWC_OHWI(__global float *input, + __global float *weight, + __global float *bias, + __global float *output, + const uint4 input_shape, // NHWC + const uint4 weight_shape, // OHWI + const uint4 output_shape, // NHWC + const uint2 stride, // HW + const uint4 pad) // top bottom left right +{ + uint ow = get_global_id(0); + uint oh = get_global_id(1); + uint co_outer = get_global_id(2); + + uint CI = input_shape.w, IH = input_shape.y, IW = input_shape.z; + uint CO = output_shape.w, OW = output_shape.z; + uint KH = weight_shape.y, KW = weight_shape.z; + uint stride_h = stride.x, stride_w = stride.y; + uint pad_top = pad.x, pad_left = pad.z; + uint CI_TILE_NUM = UP_DIV(CI, CI_TILE); + uint CO_TILE_NUM = UP_DIV(CO, CO_TILE); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + for (uint kh = 0; kh < KH; ++kh) + { + uint ih = kh + oh * stride_h - pad_top; + for (uint kw = 0; kw < KW; ++kw) + { + uint iw = kw + ow * stride_w - pad_left; + for (uint ci_outer = 0; ci_outer < CI_TILE_NUM; ++ci_outer) + { + for (uint ci_inner = 0; ci_inner < CI_TILE; ++ci_inner) + { + uint ci = ci_outer * CI_TILE + ci_inner; + if (ci >= CI) + break; + + uint input_idx = ih * IW * CI + iw * CI + ci; + float value = 0; + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + value = 0; + else + value = input[input_idx]; + + uint CO_TILE_OFFSET = KH * KW * CI; + uint weight_idx = (co_outer * CO_TILE) * CO_TILE_OFFSET + + kh * KW * CI + + kw * CI + + ci; + acc.x += weight[weight_idx + 0 * CO_TILE_OFFSET] * value; + acc.y += weight[weight_idx + 1 * CO_TILE_OFFSET] * value; + acc.z += weight[weight_idx + 2 * CO_TILE_OFFSET] * value; + acc.w += weight[weight_idx + 3 * CO_TILE_OFFSET] * value; + } + } + } + } + uint output_idx = oh * OW * CO + ow * CO + (co_outer * CO_TILE); + if (co_outer < CO_TILE_NUM - 1 || CO % CO_TILE == 0) + { + output[output_idx + 0] = acc.x + bias[co_outer * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_outer * CO_TILE + 1]; + output[output_idx + 2] = acc.z + bias[co_outer * CO_TILE + 2]; + output[output_idx + 3] = acc.w + bias[co_outer * CO_TILE + 3]; + } + else if (CO % CO_TILE == 1) + { + output[output_idx + 0] = acc.x + bias[co_outer * CO_TILE + 0]; + } + else if (CO % CO_TILE == 2) + { + output[output_idx + 0] = acc.x + bias[co_outer * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_outer * CO_TILE + 1]; + } + else if (CO % CO_TILE == 3) + { + output[output_idx + 0] = acc.x + bias[co_outer * CO_TILE + 0]; + output[output_idx + 1] = acc.y + bias[co_outer * CO_TILE + 1]; + output[output_idx + 2] = acc.z + bias[co_outer * CO_TILE + 2]; + } +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl new file mode 100644 index 00000000000..4bc003799aa --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl @@ -0,0 +1,95 @@ +#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable +#define ACCUM_FLT4 float4 +#define FLT float +#define FLT2 float2 +#define FLT3 float3 +#define FLT4 float4 +#define TO_FLT4 convert_float4 +#define TO_ACCUM_TYPE convert_float4 +#define TO_ACCUM_FLT convert_float +#define READ_IMAGE read_imagef +#define WRITE_IMAGE write_imagef +__constant sampler_t smp_edge = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +__kernel void DepthwiseConv2d_NC4HW4( +__global float4* src_data, + __global FLT4* filters, +__global FLT4* biases, + float relu_clip1, +__global float4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size +) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final =src_data[(((Z) * src_size.y + (y_c)) * src_size.x + (x_c))]; + r += TO_ACCUM_TYPE(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[(((Z) * dst_size.y + (Y)) * dst_size.x + (X))] = res0; +} + +__kernel void DepthwiseConv2d_NHWC4( +__global float4* src_data, + __global FLT4* filters, +__global FLT4* biases, + float relu_clip1, +__global float4* dst_data, + int2 kernel_size, + int2 stride, + int2 padding, + int2 dilation, + int4 src_size, + int4 dst_size +) { + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return; + ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f); + int x_offseted = X * stride.x + padding.x; + int y_offseted = Y * stride.y + padding.y; + int fx_c = Z * kernel_size.x * kernel_size.y; + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = y_offseted + ky * dilation.y; + bool outside_y = y_c < 0 || y_c >= src_size.y; + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = x_offseted + kx * dilation.x; + bool outside_x = x_c < 0 || x_c >= src_size.x; + if (!outside_x && !outside_y) { + FLT4 f = filters[fx_c]; + FLT4 src_final =src_data[((y_c * src_size.x + x_c) * src_size.z + Z)]; + r += TO_ACCUM_TYPE(src_final * f); + }; + fx_c++; + } + } + FLT4 bias_val = biases[Z]; + FLT4 res0 = TO_FLT4(r) + bias_val; + res0 = clamp(res0, (FLT)(0.0f), (FLT)(relu_clip1)); + dst_data[((Y * dst_size.x + X) * dst_size.z + Z)] = res0; +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl new file mode 100644 index 00000000000..c08e3941275 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/matmul.cl @@ -0,0 +1,31 @@ +#define FLT4 float4 +#define FLT16 float16 +__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, + __global FLT4 *buffer, __global FLT4 *bias, int2 offset_ci, + int2 offset_co, int has_bias) { + int2 gid = (int2)(get_global_id(0), get_global_id(1)); + int2 lid = (int2)(get_local_id(0), get_local_id(1)); + FLT4 s = (FLT4)(0.0f); + bool inside = gid.x < offset_co.y; + for (uint i = lid.y; i < offset_ci.y && inside; i += 4) { + FLT4 v = x[i]; + FLT16 w = weight[gid.x + i * offset_co.y]; + s.x += dot(v, w.s0123); + s.y += dot(v, w.s4567); + s.z += dot(v, w.s89ab); + s.w += dot(v, w.scdef); + } + __local FLT4 temp[64][4]; + temp[lid.x][lid.y] = s; + barrier(CLK_LOCAL_MEM_FENCE); + if (lid.y == 0 && inside) { + s += temp[lid.x][1]; + s += temp[lid.x][2]; + s += temp[lid.x][3]; + if (has_bias != 0) { + s += bias[gid.x]; + } + buffer[gid.x] = s; + // memory pollution? or protected by opencl + } +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl new file mode 100644 index 00000000000..0d61ec32bab --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl @@ -0,0 +1,68 @@ +__kernel void MaxPooling2d(__global float4 *input, __global float4 *output, const int4 input_shape, + const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { + // axis to dst tensor coordinate + int X = get_global_id(0); + int Y = get_global_id(1); + int Z = get_global_id(2); + + // boundary check + if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + return; + } + + float4 maximum = (float4)(-10000.0f); + int xs = X * stride.x + padding.x; + int ys = Y * stride.y + padding.y; + + for (int kx = 0; kx < kernel_size.x; ++kx) { + int x_c = xs + kx; + if (x_c < 0 || x_c >= input_shape.x) { + continue; + } + for (int ky = 0; ky < kernel_size.y; ++ky) { + int y_c = ys + ky; + if (y_c < 0 || y_c >= input_shape.y) { + continue; + } + float4 src = input[(input_shape.y * x_c + y_c) * input_shape.w + Z]; + maximum = max(src, maximum); + } + } + output[(output_shape.y * X + Y) * output_shape.w + Z] = maximum; +} + +// __constant sampler_t sample_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; + +//__kernel void MaxPooling2dImage2d(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, +// const int4 output_shape, const int2 stride, const int2 kernel_size, +// const int2 padding) { +// // axis to dst tensor coordinate +// int X = get_global_id(0); +// int Y = get_global_id(1); +// int Z = get_global_id(2); +// +// // boundary check +// if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { +// return; +// } +// +// float4 maximum = (float4)(-10000.0f); +// int xs = X * stride.x + padding.x; +// int ys = Y * stride.y + padding.y; +// +// for (int ky = 0; ky < kernel_size.y; ++ky) { +// int y_c = ys + ky; +// if (y_c < 0 || y_c >= input_shape.y) { +// continue; +// } +// for (int kx = 0; kx < kernel_size.x; ++kx) { +// int x_c = xs + kx; +// if (x_c < 0 || x_c >= input_shape.x) { +// continue; +// } +// float4 src = read_imagef(input, sample_none, (int2)(x_c, y_c * input_shape.w + Z)); +// maximum = max(src, maximum); +// } +// } +// write_imagef(output, (int2)(X, Y * output_shape.w + Z), maximum); +//} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl new file mode 100644 index 00000000000..a1649c3dcf8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl @@ -0,0 +1,35 @@ +#define SLICES 4 + +int DivideRoundUp(int n, int div) +{ + int q = n / div; + return n % div == 0 ? q : q + 1; +} + +__kernel void SoftMax(__global float4 *input, + __global float4 *output, + const int4 input_shape) { + int X = get_global_id(0); // width + int Y = get_global_id(1); // height + int H = input_shape.y; + int W = input_shape.z; + int C = input_shape.w; + + if (X >= W || Y >= H) return; + + float sum = 0.0f; + for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + float4 t = input[(Y * W + X * H) * C + d]; + sum += exp(t.x); + if (d * 4 + 1 < C) sum += exp(t.y); + if (d * 4 + 2 < C) sum += exp(t.z); + if (d * 4 + 3 < C) sum += exp(t.w); + } + + for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + float4 t = input[(Y * W + X * H) * C + d]; + t = exp(t) / sum; + float4 result = convert_float4(t); + output[(Y * W + X * H) * C + d] = result; + } +} \ No newline at end of file diff --git a/mindspore/lite/src/runtime/kernel/opencl/image_format.h b/mindspore/lite/src/runtime/kernel/opencl/image_format.h new file mode 100644 index 00000000000..4987afe49da --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/image_format.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ + +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore { +namespace kernel { + +/** + * MindSpore to OpenCL channel order. + * @param num_channels + * @return opencl_channels + */ +cl_channel_order ToChannelOrder(int num_channels) { + switch (num_channels) { + case 1: + return CL_R; + case 2: + return CL_RG; + case 3: + return CL_RGB; + case 4: + return CL_RGBA; + default: + return -1; + } +} + +/** + * MindSpore image channel type to OpenCL channel data type. + * @param data_type + * @return opencl_data_type + */ +cl_channel_type ToImageChannelType(TypeId data_type) { + switch (data_type) { + case kNumberTypeFloat32: + return CL_FLOAT; + case kNumberTypeFloat16: + return CL_HALF_FLOAT; + default: + return -1; + } +} +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_IMAGE_FORMAT_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc new file mode 100644 index 00000000000..78fb7d3a02f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/arithmetic.h" +#include +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/arithmetic.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; + +namespace mindspore::kernel { + +int ArithmeticOpenCLKernel::Init() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::string kernel_name = "ArithmeticAdd"; + + is_bias_add_ = false; + if (inputs_[1]->TensorType() == schema::NodeType_ValueNode && inputs_[1]->Data() != nullptr) { + kernel_name = "ArithmeticBiasAdd"; + is_bias_add_ = true; + } + + switch (opParameter->type_) { + case PrimitiveType_Mul: + if (is_bias_add_) { + weight_ = static_cast(inputs_[1]->Data())[0]; + break; + } + kernel_name = "ArithmeticMul"; + break; + case PrimitiveType_Add: + if (is_bias_add_) { + bias_ = static_cast(inputs_[1]->Data())[0]; + break; + } + kernel_name = "ArithmeticAdd"; + break; + case PrimitiveType_Sub: + if (is_bias_add_) { + bias_ = -1 * static_cast(inputs_[1]->Data())[0]; + break; + } + kernel_name = "ArithmeticSub"; + break; + case PrimitiveType_Div: + if (is_bias_add_) { + weight_ = 1 / static_cast(inputs_[1]->Data())[0]; + break; + } + kernel_name = "ArithmeticDiv"; + break; + default: + MS_LOG(ERROR) << "Error Operator type " << opParameter->type_; + break; + } + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::string program_name = "Arithmetic"; + std::set build_options; + std::string source = arithmetic_source_fp32; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + return 0; +} + +int ArithmeticOpenCLKernel::Run() { + uint32_t element_num = outputs_[0]->ElementsC4Num(); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::vector global = {element_num}; + std::vector local; + + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + if (is_bias_add_) { + MS_LOG(DEBUG) << "weight: " << weight_ << " bias: " << bias_; + ocl_runtime->SetKernelArg(kernel_, 1, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 2, weight_); + ocl_runtime->SetKernelArg(kernel_, 3, bias_); + ocl_runtime->SetKernelArg(kernel_, 4, element_num / C4NUM); + } else { + ocl_runtime->SetKernelArg(kernel_, 1, inputs_[1]->Data()); + ocl_runtime->SetKernelArg(kernel_, 2, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 3, element_num); + } + return ocl_runtime->RunKernel(kernel_, global, local, nullptr); +} + +kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ArithmeticOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Arithmetic"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_Mul, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, PrimitiveType_Add, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, PrimitiveType_Sub, OpenCLArithmeticKernelCreator) +REG_KERNEL(kGPU, PrimitiveType_Div, OpenCLArithmeticKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h new file mode 100644 index 00000000000..df6c8438401 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ + +#include +#include "src/runtime/kernel/arm/fp32/arithmetic.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +class ArithmeticOpenCLKernel : public ArithmeticCPUKernel { + public: + explicit ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : ArithmeticCPUKernel(parameter, inputs, outputs, ctx) {} + ~ArithmeticOpenCLKernel() override {}; + + int Init() override; + int Run() override; + + private: + cl::Kernel kernel_; + bool is_bias_add_{false}; + float weight_{1.f}; + float bias_{.0f}; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_ARITHMETIC_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc new file mode 100644 index 00000000000..a3938841137 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/opencl/kernel/concat.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/opclib/concat_parameter.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/concat.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Concat; + +namespace mindspore::kernel { + +int ConcatOpenCLKernel::Init() { + if (inputs_[0]->shape().size() != 4) { + MS_LOG(ERROR) << "only support dim=4"; + } + + auto param = reinterpret_cast(this->opParameter); + MS_LOG(INFO) << "concat at axis=: " << param->axis_; + if (param->axis_ != 0 && param->axis_ != 3) { + MS_LOG(ERROR) << "only support axis=0 or axis=3"; + } + + if (param->axis_ == 0) { + return 0; + } + + std::string kernel_name = "Concat"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; + std::string source = concat_source_fp32; + std::string program_name = "Concat"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + return 0; +} + +int ConcatOpenCLKernel::ReSize() { return 0; } + +int ConcatOpenCLKernel::Run_axis0() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator_ = ocl_runtime->GetAllocator(); + cl::CommandQueue *command_queue = ocl_runtime->GetDefaultCommandQueue(); + + for (auto &tensor : inputs_) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->MapBuffer(*buffer, CL_MAP_READ, tensor->Size(), command_queue, true); + } + for (auto &tensor : outputs_) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->MapBuffer(*buffer, CL_MAP_WRITE, tensor->Size(), command_queue, true); + } + + memcpy_s(outputs_[0]->Data(), inputs_[0]->Size(), inputs_[0]->Data(), inputs_[0]->Size()); + memcpy_s(reinterpret_cast(outputs_[0]->Data()) + inputs_[0]->Size(), inputs_[1]->Size(), inputs_[1]->Data(), + inputs_[1]->Size()); + + for (auto tensors : {&inputs_, &outputs_}) { + for (auto &tensor : *tensors) { + auto buffer = static_cast(allocator_->GetDeviceBuffer(tensor->Data())); + ocl_runtime->UnmapBuffer(*buffer, tensor->Data()); + } + } + return 0; +} + +int ConcatOpenCLKernel::Run() { + auto param = reinterpret_cast(this->opParameter); + if (param->axis_ == 0) { + return Run_axis0(); + } + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::vector local = {1, 1, 1}; + std::vector global = {1, 1, 1}; + + auto input0_shape = inputs_[0]->shape(); + auto input1_shape = inputs_[1]->shape(); + auto output_shape = outputs_[0]->shape(); + cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; + cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; + cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; + + int arg_cn = 0; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); + + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_Concat, OpenCLConcatKernelCreator); +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h new file mode 100644 index 00000000000..7f16a791032 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ + +#include +#include +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/concat_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/opclib/fp32/concat.h" +#include "src/runtime/kernel/arm/opclib/int8/concat_int8.h" + +namespace mindspore::kernel { + +class ConcatOpenCLKernel : public LiteKernel { + public: + explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + + ~ConcatOpenCLKernel() override{}; + + int Init() override; + + int InferShape() { return {}; } + + int ReSize() override; + + int Run_axis0(); + + int Run() override; + + private: + cl::Kernel kernel_; +}; + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc new file mode 100644 index 00000000000..ccf5fdfa8c7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/conv2d_transpose2x2.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/conv2d_transpose2x2.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_DeConv2D; + +namespace mindspore::kernel { + +int Conv2dTransposeOpenCLKernel::Init() { + ConvParameter *param = reinterpret_cast(opParameter); + if (param->kernel_h_ != 2 || param->kernel_w_ != 2 || param->stride_h_ != 2 || param->stride_w_ != 2) { + MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2."; + return 1; + } + if (param->pad_h_ >= 2 || param->pad_w_ >= 2) { + MS_LOG(ERROR) << "only support pad in {0,1}."; + return 1; + } + std::string kernel_name = "conv2d_transpose2x2"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else +#ifdef ENABLE_FP16 + std::string source = conv2d_transpose2x2_source_fp16; +#else + std::string source = conv2d_transpose2x2_source_fp32; +#endif + std::set build_options; + std::string program_name = "conv2d_transpose2x2"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int div_ci = UP_DIV(ci, 4); + int div_co = UP_DIV(co, 4); + auto allocator = ocl_runtime->GetAllocator(); + padWeight_ = reinterpret_cast(allocator->Malloc(div_ci * div_co * 16 * kh * kw * sizeof(FLOAT_T))); + padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); + bias_ = reinterpret_cast(allocator->Malloc(div_co * 4 * sizeof(FLOAT_T))); + bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); + PadWeight(); + allocator->UnmapBuffer(padWeight_); + allocator->UnmapBuffer(bias_); + return 0; +} + +int Conv2dTransposeOpenCLKernel::ReSize() { return 0; } + +void Conv2dTransposeOpenCLKernel::PadWeight() { + // OHWI to OIHW4(I)4(O) + ConvParameter *param = reinterpret_cast(opParameter); + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int div_ci = UP_DIV(ci, 4); + int div_co = UP_DIV(co, 4); + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + auto origin_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + bool has_bias = origin_bias != nullptr; + int index = 0; + for (int co_i = 0; co_i < div_co; co_i++) { + for (int ci_i = 0; ci_i < div_ci; ci_i++) { + for (int kh_i = 0; kh_i < kh; kh_i++) { + for (int kw_i = 0; kw_i < kw; kw_i++) { + for (int ci4_i = 0; ci4_i < 4; ci4_i++) { + for (int co4_i = 0; co4_i < 4; co4_i++) { + int co_offset = co_i * 4 + co4_i; + int ci_offset = ci_i * 4 + ci4_i; + if (co_offset < co && ci_offset < ci) { + int ori_index = ((co_offset * kh + kh_i) * kw + kw_i) * ci + ci_offset; + padWeight_[index++] = origin_weight[ori_index]; + } else { + padWeight_[index++] = 0.; + } + } + } + } + } + } + } + for (int co_i = 0; co_i < div_co; co_i++) { + for (int co4_i = 0; co4_i < 4; co4_i++) { + int co_offset = co_i * 4 + co4_i; + if (has_bias && co_offset < co) { + bias_[co_offset] = origin_bias[co_offset]; + } else { + bias_[co_offset] = 0.; + } + } + } +} + +int Conv2dTransposeOpenCLKernel::Run() { + std::vector shapex = inputs_[0]->shape(); + int n = shapex[0]; + if (n > 1) { + MS_LOG(ERROR) << "Conv2dTranspose n > 1 not supported!"; + return 1; + } + ConvParameter *param = reinterpret_cast(opParameter); + int ci = param->input_channel_; + int co = param->output_channel_; + int kh = param->kernel_h_; + int kw = param->kernel_w_; + int pad = kh - 1 - param->pad_h_; + int oh = outputs_[0]->shape()[1]; + int ow = outputs_[0]->shape()[2]; + int h = inputs_[0]->shape()[1]; + int w = inputs_[0]->shape()[2]; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + // local size should less than MAX_GROUP_SIZE + std::vector local = {4, 4, 32}; + std::vector global = {UP_ROUND((size_t)oh / 2, local[0]), UP_ROUND((size_t)ow / 2, local[1]), + UP_ROUND((size_t)co / 4, local[2])}; + + cl_int2 kernel_size = {kh, kw}; + cl_int2 stride = {2, 2}; + cl_int2 padding = {pad, pad}; + cl_int4 src_size = {h, w, UP_DIV(ci, 4), 1}; + cl_int4 dst_size = {oh, ow, UP_DIV(co, 4), 1}; + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); + ocl_runtime->SetKernelArg(kernel_, 2, bias_); + ocl_runtime->SetKernelArg(kernel_, 3, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 4, kernel_size); + ocl_runtime->SetKernelArg(kernel_, 5, stride); + ocl_runtime->SetKernelArg(kernel_, 6, padding); + ocl_runtime->SetKernelArg(kernel_, 7, src_size); + ocl_runtime->SetKernelArg(kernel_, 8, dst_size); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() + // << ", type: " << lite::EnumNameOpT(opDef.attr_type()); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_DeConv2D, OpenCLConv2dTransposeKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h new file mode 100644 index 00000000000..992c96d8578 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +#ifdef ENABLE_FP16 +using FLOAT_T = float16_t; +#else +using FLOAT_T = float; +#endif + +namespace mindspore::kernel { + +class Conv2dTransposeOpenCLKernel : public LiteKernel { + public: + explicit Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~Conv2dTransposeOpenCLKernel() override {}; + + int Init() override; + int InferShape() {} + int ReSize() override; + int Run() override; + void PadWeight(); + + private: + ConvParameter *parameter_; + cl::Kernel kernel_; + FLOAT_T *padWeight_; + FLOAT_T *bias_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc new file mode 100644 index 00000000000..4fcbfd75da6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/convolution.h" +#include +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/convolution.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Conv2D; + +namespace mindspore::kernel { + +int ConvolutionOpenCLKernel::Init() { + MS_LOG(INFO) << "ConvolutionOpenCLKernel::Init()"; + + if (inputs_[0]->Batch() != 1 || outputs_[0]->Batch() != 1) { + MS_LOG(ERROR) << "ConvolutionOpenCLKernel only support Batch=1!"; + } + + auto io_NHWC = inputs_[0]->GetFormat() == schema::Format_NHWC && outputs_[0]->GetFormat() == schema::Format_NHWC; + auto io_NHWC4 = inputs_[0]->GetFormat() == schema::Format_NHWC4 && outputs_[0]->GetFormat() == schema::Format_NHWC4; + if (!io_NHWC && !io_NHWC4) { + MS_LOG(ERROR) << "input and output data_format is invalid!"; + } + io_dataformat_ = inputs_[0]->GetFormat(); + + if (inputs_[1]->GetFormat() != schema::Format_KHWC) { + MS_LOG(ERROR) << "weight data_format is invalid!"; + } + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::string kernel_name = "convolution_NHWC_OHWI"; +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; + std::string source = convolution_source_fp32; + std::string program_name = "convolution"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + + this->InitBuffer(); + return 0; +} +int ConvolutionOpenCLKernel::InitBuffer() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + auto weight_tensor = inputs_[1]; + auto bias_tensor = inputs_[2]; + if (io_dataformat_ == schema::Format_NHWC) { + packed_weight_ = reinterpret_cast(allocator->Malloc(weight_tensor->Size())); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + memcpy_s(packed_weight_, weight_tensor->Size(), weight_tensor->Data(), weight_tensor->Size()); + allocator->UnmapBuffer(packed_weight_); + + packed_bias_ = reinterpret_cast(allocator->Malloc(bias_tensor->Size())); + packed_bias_ = reinterpret_cast(allocator->MapBuffer(packed_bias_, CL_MAP_WRITE, nullptr, true)); + memcpy_s(packed_bias_, bias_tensor->Size(), bias_tensor->Data(), bias_tensor->Size()); + allocator->UnmapBuffer(packed_bias_); + } else if (io_dataformat_ == schema::Format_NHWC4) { + auto weight_shape = weight_tensor->shape(); + size_t CO = weight_shape[0]; + size_t KH = weight_shape[1]; + size_t KW = weight_shape[2]; + size_t CI = weight_shape[3]; + size_t CI_ALIGN = UP_DIV(CI, C4NUM) * C4NUM; + size_t CO_ALIGN = UP_DIV(CO, C4NUM) * C4NUM; + size_t weight_size_tiled = CO_ALIGN * KH * KW * CI_ALIGN * sizeof(float); + + packed_weight_ = reinterpret_cast(allocator->Malloc(weight_size_tiled)); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + memset_s(packed_weight_, weight_size_tiled, 0x00, weight_size_tiled); + auto weight_data = reinterpret_cast(weight_tensor->Data()); + for (int co = 0; co < CO; ++co) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + for (int ci = 0; ci < CI; ++ci) { + packed_weight_[co * KH * KW * CI_ALIGN + kh * KW * CI_ALIGN + kw * CI_ALIGN + ci] = + weight_data[co * KH * KW * CI + kh * KW * CI + kw * CI + ci]; + } + } + } + } + allocator->UnmapBuffer(packed_weight_); + + size_t bias_size_tiled = CO_ALIGN * sizeof(float); + packed_bias_ = reinterpret_cast(allocator->Malloc(bias_size_tiled)); + packed_bias_ = reinterpret_cast(allocator->MapBuffer(packed_bias_, CL_MAP_WRITE, nullptr, true)); + memset_s(packed_bias_, bias_size_tiled, 0x00, bias_size_tiled); + auto bias_data = reinterpret_cast(bias_tensor->Data()); + for (int co = 0; co < CO; ++co) { + packed_bias_[co] = bias_data[co]; + } + allocator->UnmapBuffer(packed_bias_); + } + + return 0; +} + +int ConvolutionOpenCLKernel::ReSize() { return 0; } + +int ConvolutionOpenCLKernel::Run() { + MS_LOG(INFO) << "ConvolutionOpenCLKernel::Run()"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + auto param = reinterpret_cast(opParameter); + auto input0_shape = inputs_[0]->shape(); // NHWC + auto input1_shape = inputs_[1]->shape(); // OHWI + auto outpu0_shape = outputs_[0]->shape(); // NHWC + cl_uint N = input0_shape[0]; + cl_uint CI = input0_shape[3]; + cl_uint IH = input0_shape[1]; + cl_uint IW = input0_shape[2]; + cl_uint CO = outpu0_shape[3]; + cl_uint OH = outpu0_shape[1]; + cl_uint OW = outpu0_shape[2]; + cl_uint KH = input1_shape[1]; + cl_uint KW = input1_shape[2]; + cl_uint CI_TILE_NUM = UP_DIV(CI, C4NUM); + cl_uint CO_TILE_NUM = UP_DIV(CO, C4NUM); + cl_uint CI_ALIGN = CI_TILE_NUM * C4NUM; + cl_uint CO_ALIGN = CO_TILE_NUM * C4NUM; + + cl_uint4 input_shape; + cl_uint4 weight_shape; + cl_uint4 output_shape; + if (io_dataformat_ == schema::Format_NHWC) { + input_shape = {N, IH, IW, CI}; + weight_shape = {CO, KH, KW, CI}; + output_shape = {N, OH, OW, CO}; + } else if (io_dataformat_ == schema::Format_NHWC4) { + input_shape = {N, IH, IW, CI_ALIGN}; + weight_shape = {CO_ALIGN, KH, KW, CI_ALIGN}; + output_shape = {N, OH, OW, CO_ALIGN}; + } + cl_uint2 stride = {static_cast(param->stride_h_), static_cast(param->stride_w_)}; + cl_uint4 pad = {static_cast(param->pad_u_), static_cast(param->pad_d_), + static_cast(param->pad_l_), static_cast(param->pad_r_)}; + + int arg_cn = 0; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, packed_weight_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, packed_bias_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, weight_shape); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, stride); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, pad); + + std::vector global = {OW, OH, CO_TILE_NUM}; + std::vector local = {1, 1, CO_TILE_NUM}; + + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + + return 0; +} + +kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new ConvolutionOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Convolution kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h new file mode 100644 index 00000000000..0ef0aa68713 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONVOLUTIONOPENCLKERNEL_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONVOLUTIONOPENCLKERNEL_H_ + +#include +#include "src/runtime/kernel/arm/fp32/convolution.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" + +namespace mindspore::kernel { + +class ConvolutionOpenCLKernel : public LiteKernel { + public: + explicit ConvolutionOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~ConvolutionOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + schema::Format io_dataformat_ = schema::Format_NHWC4; + float *packed_weight_ = nullptr; + float *packed_bias_ = nullptr; + cl::Kernel kernel_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONVOLUTIONOPENCLKERNEL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc new file mode 100644 index 00000000000..e00820f25b7 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" +#include "src/runtime/kernel/arm/opclib/pack.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/depthwise_conv2d.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/depthwise_conv2d.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_DepthwiseConv2D; + + +namespace mindspore::kernel { + +int DepthwiseConv2dOpenCLKernel::Init() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + std::string kernel_name = "DepthwiseConv2d_NHWC4"; + auto in_format = inputs_[0]->GetFormat(); + outputs_[0]->SetFormat(in_format); + if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) { + MS_LOG(ERROR) << "input format(" << in_format << ") " << "format not support!"; + } + if (in_format == schema::Format_NC4HW4) { + kernel_name = "DepthwiseConv2d_NC4HW4"; + } +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::string program_name = "DepthwiseConv2d"; + std::set build_options; +#ifdef ENABLE_FP16 + std::string source = depthwise_conv2d_source_fp16; +#else + std::string source = depthwise_conv2d_source_fp32; +#endif + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + this->InitBuffer(); + return 0; +} +int DepthwiseConv2dOpenCLKernel::InitBuffer() { + auto parameter = reinterpret_cast(opParameter); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + // weight: o, h, w, i; o == group, i == 1 + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + int CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + int pack_weight_size = C4NUM * CO4 * parameter->kernel_h_ * parameter->kernel_w_; + + packed_weight_ = reinterpret_cast(allocator->Malloc(pack_weight_size * sizeof(FLOAT_t))); + packed_weight_ = reinterpret_cast(allocator->MapBuffer(packed_weight_, CL_MAP_WRITE, nullptr, true)); + int plane = parameter->kernel_h_ * parameter->kernel_w_; +#ifdef ENABLE_FP16 + PackNCHWToNC4HW4Fp16(origin_weight, packed_weight_, 1, plane, outputs_[0]->Channel()); +#else + PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, plane, outputs_[0]->Channel()); +#endif + + allocator->UnmapBuffer(packed_weight_); + + // init bias + if (inputs_.size() == kInputSize2) { + bias_data_ = reinterpret_cast(allocator->Malloc(C4NUM * CO4 * sizeof(FLOAT_t))); + bias_data_ = reinterpret_cast(allocator->MapBuffer(bias_data_, CL_MAP_WRITE, nullptr, true)); + size_t up_co_size = C4NUM * CO4 * sizeof(FLOAT_t); + memset_s(bias_data_, up_co_size, 0, up_co_size); + auto ori_bias = reinterpret_cast(inputs_.at(kBiasIndex)->Data()); + memcpy_s(bias_data_, outputs_[0]->Channel() * sizeof(FLOAT_t), ori_bias, outputs_[0]->Channel() * sizeof(FLOAT_t)); + allocator->UnmapBuffer(bias_data_); + } else { + MS_ASSERT(inputs_.size() == kInputSize1); + } + return 0; +} +int DepthwiseConv2dOpenCLKernel::ReSize() { + return 0; +} + +int DepthwiseConv2dOpenCLKernel::Run() { + auto parameter = reinterpret_cast(opParameter); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); + std::vector global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; + std::vector local = {1, 1, 1}; + + float relu_clip1 = 6.0; + cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_}; + cl_int2 stride = {parameter->stride_h_, parameter->stride_w_}; + cl_int2 padding = {-parameter->pad_h_, -parameter->pad_w_}; + cl_int2 dilation = {parameter->dilation_h_, parameter->dilation_w_}; + cl_int4 src_size = {inputs_[0]->Width(), inputs_[0]->Height(), (cl_int)CI4, inputs_[0]->Batch()}; + cl_int4 dst_size = {(cl_int)outputs_[0]->Width(), (cl_int)outputs_[0]->Height(), (cl_int)CO4, + (cl_int)outputs_[0]->Batch()}; + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 1, packed_weight_); + ocl_runtime->SetKernelArg(kernel_, 2, bias_data_); + ocl_runtime->SetKernelArg(kernel_, 3, relu_clip1); + ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 5, kernel_size); + ocl_runtime->SetKernelArg(kernel_, 6, stride); + ocl_runtime->SetKernelArg(kernel_, 7, padding); + ocl_runtime->SetKernelArg(kernel_, 8, dilation); + ocl_runtime->SetKernelArg(kernel_, 9, src_size); + ocl_runtime->SetKernelArg(kernel_, 10, dst_size); + + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init DepthwiseConv2dOpenCLKernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_DepthwiseConv2D, OpenCLDepthwiseConv2dKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h new file mode 100644 index 00000000000..cfc06480574 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" + + +namespace mindspore::kernel { + +class DepthwiseConv2dOpenCLKernel : public LiteKernel { + public: + explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs), + packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} + ~DepthwiseConv2dOpenCLKernel() override {}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + FLOAT_t *packed_weight_; + FLOAT_t *bias_data_; + cl::Kernel kernel_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc new file mode 100644 index 00000000000..116f22175d5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/arm/opclib/fp32/matmul.h" +#include "src/runtime/kernel/opencl/kernel/matmul.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp16/matmul.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/matmul.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_FullConnection; +using mindspore::schema::PrimitiveType_MatMul; + +namespace mindspore::kernel { + +int MatMulOpenCLKernel::Init() { + std::string kernel_name = "MatMul"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; +// build_options.emplace("-DPOOL_AVG"); +#ifdef ENABLE_FP16 + std::string source = matmul_source_fp16; +#else + std::string source = matmul_source_fp32; +#endif + std::string program_name = "MatMul"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + int ci = inputs_[1]->shape()[1]; + int co = inputs_[1]->shape()[0]; + sizeCI = {ci, UP_DIV(ci, 4)}; + sizeCO = {co, UP_DIV(co, 4)}; + auto allocator = ocl_runtime->GetAllocator(); + padWeight_ = reinterpret_cast(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * 16 * sizeof(FLOAT_T))); + padWeight_ = reinterpret_cast(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true)); + if (hasBias_) { + bias_ = reinterpret_cast(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T))); + bias_ = reinterpret_cast(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true)); + } + PadWeight(); + allocator->UnmapBuffer(padWeight_); + if (hasBias_) { + allocator->UnmapBuffer(bias_); + } + return 0; +} + +int MatMulOpenCLKernel::ReSize() { return 0; } + +void MatMulOpenCLKernel::PadWeight() { + auto origin_weight = reinterpret_cast(inputs_.at(kWeightIndex)->Data()); + int divCI = sizeCI.s[1]; + int divCO = sizeCO.s[1]; + int index = 0; + for (int i = 0; i < divCI; ++i) { + for (int j = 0; j < divCO; ++j) { + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 4; ++l) { + int src_x = i * 4 + l; + int src_y = j * 4 + k; + if (src_x < sizeCI.s[0] && src_y < sizeCO.s[0]) { + padWeight_[index++] = origin_weight[src_y * sizeCI.s[0] + src_x]; + } else { + padWeight_[index++] = 0; + } + } + } + } + } + if (hasBias_) { + memcpy(inputs_[2]->Data(), bias_, sizeof(FLOAT_T) * sizeCI.s[0]); + for (int i = sizeCI.s[0]; i < sizeCI.s[1] * 4; i++) { + bias_[i] = 0; + } + } +} + +int MatMulOpenCLKernel::Run() { + std::vector shapex = inputs_[0]->shape(); + int n = shapex[0]; + if (n > 1) { + MS_LOG(ERROR) << "MatMul n > 1 not supported!"; + return 1; + } + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + // local size should less than MAX_GROUP_SIZE + std::vector local = {64, 4}; + std::vector global = {UP_ROUND(sizeCO.s[1], local[0]), 4}; + + ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, 1, padWeight_); + ocl_runtime->SetKernelArg(kernel_, 2, outputs_[0]->Data()); + if (hasBias_) { + ocl_runtime->SetKernelArg(kernel_, 3, inputs_[2]->Data()); + } else { + ocl_runtime->SetKernelArg(kernel_, 3, nullptr); + } + ocl_runtime->SetKernelArg(kernel_, 4, sizeCI); + ocl_runtime->SetKernelArg(kernel_, 5, sizeCO); + ocl_runtime->SetKernelArg(kernel_, 6, hasBias_ ? 1 : 0); + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; +} + +kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + bool hasBias = false; + if (opParameter->type_ == PrimitiveType_FullConnection) { + hasBias = (reinterpret_cast(opParameter))->has_bias_; + } + auto *kernel = new MatMulOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs, hasBias); + auto ret = kernel->Init(); + if (0 != ret) { + // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str() + // << ", type: " << lite::EnumNameOpT(opDef.attr_type()); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_MatMul, OpenCLMatMulKernelCreator) +REG_KERNEL(kGPU, PrimitiveType_FullConnection, OpenCLMatMulKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h new file mode 100644 index 00000000000..a1c4485dd9d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/conv_parameter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +#ifdef ENABLE_FP16 +using FLOAT_T = float16_t; +#else +using FLOAT_T = float; +#endif + +namespace mindspore::kernel { + +class MatMulOpenCLKernel : public LiteKernel { + public: + explicit MatMulOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, bool hasBias) + : LiteKernel(parameter, inputs, outputs) { + hasBias_ = hasBias; + } + ~MatMulOpenCLKernel() override{}; + + int Init() override; + int InferShape() {} + int ReSize() override; + int Run() override; + void PadWeight(); + + private: + cl::Kernel kernel_; + FLOAT_T *padWeight_; + FLOAT_T *bias_; + bool hasBias_ = false; + cl_int2 sizeCI; + cl_int2 sizeCO; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_MATMUL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc new file mode 100644 index 00000000000..a26fc9db4f6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -0,0 +1,141 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/pooling2d.h" +#include +#include +#include "include/errorcode.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/opencl/opencl_wrapper.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/image_format.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/max_pool2d.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/avg_pool2d.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_INVALID_OP_NAME; +using mindspore::lite::RET_MEMORY_FAILED; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Pooling; + +namespace mindspore { +namespace kernel { +int PoolingOpenCLKernel::Init() { + std::string kernel_name; +#ifndef PROGRAM_WITH_IL + std::string source; + std::string program_name; +#endif + if (parameter_->max_pooling_) { + kernel_name = "MaxPooling2d"; +#ifndef PROGRAM_WITH_IL + source = max_pool2d_source_fp32; + program_name = "MaxPooling2d"; +#endif + } else if (parameter_->avg_pooling_) { + kernel_name = "AvgPooling2d"; +#ifndef PROGRAM_WITH_IL + source = avg_pool2d_source_fp32; + program_name = "AvgPooling2d"; +#endif + } else { + MS_LOG(ERROR) << "Init `Pooling2d` kernel failed!"; + return RET_INVALID_OP_NAME; + } + auto in_format = inputs_[0]->GetFormat(); + outputs_[0]->SetFormat(in_format); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + return RET_OK; +} + +std::vector PoolingOpenCLKernel::InitGlobalSize() const { + const size_t global_x = outputs_[0]->Height(); + const size_t global_y = outputs_[0]->Width(); + const size_t global_z = UP_ROUND_DIV(outputs_[0]->Channel(), 4); + std::vector global = {global_x, global_y, global_z}; + return global; +} + +int PoolingOpenCLKernel::InitBuffer() { return 0; } +int PoolingOpenCLKernel::ReSize() { return 0; } + +int PoolingOpenCLKernel::Run() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + + // attribute + int slices = UP_ROUND_DIV(outputs_[0]->Channel(), 4); + cl_int4 input_shape = {inputs_[0]->Height(), inputs_[0]->Width(), inputs_[0]->Channel(), slices}; + cl_int4 output_shape = {outputs_[0]->Height(), outputs_[0]->Width(), outputs_[0]->Channel(), slices}; + cl_int2 stride = {parameter_->stride_h_, parameter_->stride_w_}; + cl_int2 kernel_size = {parameter_->window_h_, parameter_->window_w_}; + cl_int2 padding = {parameter_->pad_u_, parameter_->pad_l_}; + + // binding parameters + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, output_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, stride); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, kernel_size); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, padding); + + // set work group size + std::vector local_size; + std::vector global_size = InitGlobalSize(); + int max_work_group_size = ocl_runtime->GetKernelMaxWorkGroupSize(kernel_(), (*ocl_runtime->Device())()); + local_size = GetLocalSize(global_size, max_work_group_size); + global_size = GetGlobalSize(local_size, global_size); + + // run opengl kernel + ocl_runtime->RunKernel(kernel_, global_size, local_size, nullptr); + + return RET_OK; +} + +kernel::LiteKernel *OpenCLPooling2dKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new PoolingOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create OpenCL Pooling kernel failed!"; + return nullptr; + } + auto ret = kernel->Init(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Init OpenCL Pooling kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_Pooling, OpenCLPooling2dKernelCreator) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h new file mode 100644 index 00000000000..a39b43d187d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/pooling.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +class PoolingOpenCLKernel : public LiteKernel { + public: + explicit PoolingOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + parameter_ = reinterpret_cast(parameter); + } + ~PoolingOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + std::vector InitGlobalSize() const; + + PoolingParameter *parameter_; + cl::Kernel kernel_; +}; + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc new file mode 100644 index 00000000000..4db2edc6123 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/softmax.h" +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" +#ifndef PROGRAM_WITH_IL +#include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" +#endif + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_SoftMax; + +namespace mindspore { +namespace kernel { +int SoftmaxOpenCLKernel::Init() { + std::string kernel_name = "SoftMax"; + if (parameter_->axis_ != -1 && parameter_->axis_ != 3) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; + return -1; + } + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); +#ifdef PROGRAM_WITH_IL + ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); +#else + std::set build_options; + std::string source = softmax_source_fp32; + std::string program_name = "SoftMax"; + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif + return 0; +} + +int SoftmaxOpenCLKernel::InitBuffer() { return 0; } +int SoftmaxOpenCLKernel::ReSize() { return 0; } + +int SoftmaxOpenCLKernel::Run() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto allocator = ocl_runtime->GetAllocator(); + + // global and local workers + const uint32_t grid_x = inputs_[0]->shape()[2]; // W + const uint32_t grid_y = inputs_[0]->shape()[1]; // H + const uint32_t grid_z = 1; + std::vector global = {grid_x, grid_y, grid_z}; + std::vector local = {1, 1, 1}; + + // input and output + cl::Buffer *input = reinterpret_cast(allocator->GetDeviceBuffer(inputs_[0]->Data())); + cl::Buffer *output = reinterpret_cast(allocator->GetDeviceBuffer(outputs_[0]->Data())); + cl_int4 input_size = {inputs_[0]->shape()[0], inputs_[0]->shape()[1], inputs_[0]->shape()[2], inputs_[0]->shape()[3]}; + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size); + + // run opengl kernel + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + + return 0; +} + +kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc) { + auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (inputs[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported multi-batch."; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator) +} // namespace kernel +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h new file mode 100644 index 00000000000..393fb846bd2 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ + +#include + +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/softmax.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore { +namespace kernel { +class SoftmaxOpenCLKernel : public LiteKernel { + public: + explicit SoftmaxOpenCLKernel(OpParameter *parameter, + const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) { + parameter_ = reinterpret_cast(parameter); + } + ~SoftmaxOpenCLKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + int InitBuffer(); + + private: + SoftmaxParameter *parameter_; + cl::Kernel kernel_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc new file mode 100644 index 00000000000..dc055112788 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "src/runtime/opencl/opencl_executor.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::kernel { + +SubGraphOpenCLKernel::~SubGraphOpenCLKernel() { UnInit(); } + +int SubGraphOpenCLKernel::Init() { + allocator_ = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); + for (const auto tensor : inputs_) { + tensor->MallocData(allocator_); + } + for (const auto tensor : outputs_) { + tensor->MallocData(allocator_); + } + // Map buffer for write, it is not necessary for fine-grained + for (auto &tensor : inputs_) { + void *data = allocator_->MapBuffer(tensor->Data(), CL_MAP_WRITE, nullptr, true); + // It is required with coarse-grained SVM + if (data != nullptr) { + tensor->SetData(data); + } else { + MS_LOG(ERROR) << "OpenCL kernel must use GPU buffer pointer, " + << "please make sure that this buffer allocate by OpenCLAllocator!"; + } + } + return 0; +} + +int SubGraphOpenCLKernel::UnInit() { + for (auto &tensor : outputs_) { + allocator_->UnmapBuffer(tensor->Data()); + } + for (const auto tensor : inputs_) { + if (tensor != nullptr) { + tensor->FreeData(allocator_); + } + } + for (const auto tensor : outputs_) { + if (tensor != nullptr) { + tensor->FreeData(allocator_); + } + } + return 0; +} + +int SubGraphOpenCLKernel::InferShape() { return 0; } + +int SubGraphOpenCLKernel::ReSize() { return 0; } + +int SubGraphOpenCLKernel::Run() { + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + for (auto &tensor : inputs_) { + allocator_->UnmapBuffer(tensor->Data()); + } + + lite::opencl::OpenCLExecutor executor; + executor.Run(inputs_, outputs_, nodes_, allocator_); + ocl_runtime->SyncCommandQueue(); + for (auto &tensor : outputs_) { + void *data = allocator_->MapBuffer(tensor->Data(), CL_MAP_READ, nullptr, true); + tensor->SetData(data); + } + return 0; +} + +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h new file mode 100644 index 00000000000..7786067cb21 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/opencl/opencl_allocator.h" + +namespace mindspore::kernel { + +struct SubGraphOpenCLParameter { + OpParameter op_parameter; + int input_size; + int output_size; +}; + +class SubGraphOpenCLKernel : public SubGraphKernel { + public: + explicit SubGraphOpenCLKernel(const std::vector inputs, + const std::vector outputs, + const std::vector inKernels, + const std::vector outKernels, + const std::vector nodes) + : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes) {} + ~SubGraphOpenCLKernel() override; + + int Init() override; + int InferShape() override; + int ReSize() override; + int Run() override; + int UnInit(); + + private: + SubGraphOpenCLParameter *subgraph_ocl_parameter_; + lite::opencl::OpenCLAllocator *allocator_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KERNEL_H_ + diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc new file mode 100644 index 00000000000..3940f0db886 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/utils.h" +#include +#include +#include + +namespace mindspore { +namespace kernel { + +std::vector GetGlobalSize(const std::vector &local, const std::vector &global) { + std::vector result(3, 1); + for (int i = 0; i < 3; ++i) { + result[i] = AlignByN(global[i], local[i]); + } + return result; +} + +std::vector GetLocalSize(const std::vector &global, int max_size) { + size_t wg_z = GetBiggestDividerWithPriority(global[2], 8); + size_t wg_xy_size = max_size / wg_z; + size_t wg_x = std::min(DivideRoundUp(global[0], 2), wg_xy_size); + size_t wg_y = std::min(wg_xy_size / wg_x, global[1]); + std::vector local = {wg_x, wg_y, wg_z}; + return local; +} + +std::string CLErrorCode(cl_int error_code) { + switch (error_code) { + case CL_SUCCESS: + return "Success"; + case CL_DEVICE_NOT_FOUND: + return "Device not found"; + case CL_DEVICE_NOT_AVAILABLE: + return "Device not available"; + case CL_COMPILER_NOT_AVAILABLE: + return "Compiler not available"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "Memory object allocation failure"; + case CL_OUT_OF_RESOURCES: + return "Out of resources"; + case CL_OUT_OF_HOST_MEMORY: + return "Out of host memory"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "Profiling information not available"; + case CL_MEM_COPY_OVERLAP: + return "Memory copy overlap"; + case CL_IMAGE_FORMAT_MISMATCH: + return "Image format mismatch"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "Image format not supported"; + case CL_BUILD_PROGRAM_FAILURE: + return "Build program failure"; + case CL_MAP_FAILURE: + return "Mapping failure"; + case CL_MISALIGNED_SUB_BUFFER_OFFSET: + return "Misaligned sub-buffer offset"; + case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST: + return "Execution status error for events in wait list"; + case CL_COMPILE_PROGRAM_FAILURE: + return "Compile program failure"; + case CL_LINKER_NOT_AVAILABLE: + return "Linker not available"; + case CL_LINK_PROGRAM_FAILURE: + return "Link program failure"; + case CL_DEVICE_PARTITION_FAILED: + return "Device partition failed"; + case CL_KERNEL_ARG_INFO_NOT_AVAILABLE: + return "Kernel argument information not available"; + case CL_INVALID_VALUE: + return "Invalid value"; + case CL_INVALID_DEVICE_TYPE: + return "Invalid device type"; + case CL_INVALID_PLATFORM: + return "Invalid platform"; + case CL_INVALID_DEVICE: + return "Invalid device"; + case CL_INVALID_CONTEXT: + return "Invalid context"; + case CL_INVALID_QUEUE_PROPERTIES: + return "Invalid queue properties"; + case CL_INVALID_COMMAND_QUEUE: + return "Invalid command queue"; + case CL_INVALID_HOST_PTR: + return "Invalid host pointer"; + case CL_INVALID_MEM_OBJECT: + return "Invalid memory object"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "Invalid image format descriptor"; + case CL_INVALID_IMAGE_SIZE: + return "Invalid image size"; + case CL_INVALID_SAMPLER: + return "Invalid sampler"; + case CL_INVALID_BINARY: + return "Invalid binary"; + case CL_INVALID_BUILD_OPTIONS: + return "Invalid build options"; + case CL_INVALID_PROGRAM: + return "Invalid program"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "Invalid program executable"; + case CL_INVALID_KERNEL_NAME: + return "Invalid kernel name"; + case CL_INVALID_KERNEL_DEFINITION: + return "Invalid kernel definition"; + case CL_INVALID_KERNEL: + return "Invalid kernel"; + case CL_INVALID_ARG_INDEX: + return "Invalid argument index"; + case CL_INVALID_ARG_VALUE: + return "Invalid argument value"; + case CL_INVALID_ARG_SIZE: + return "Invalid argument size"; + case CL_INVALID_KERNEL_ARGS: + return "Invalid kernel arguments"; + case CL_INVALID_WORK_DIMENSION: + return "Invalid work dimension"; + case CL_INVALID_WORK_GROUP_SIZE: + return "Invalid work group size"; + case CL_INVALID_WORK_ITEM_SIZE: + return "Invalid work item size"; + case CL_INVALID_GLOBAL_OFFSET: + return "Invalid global offset"; + case CL_INVALID_EVENT_WAIT_LIST: + return "Invalid event wait list"; + case CL_INVALID_EVENT: + return "Invalid event"; + case CL_INVALID_OPERATION: + return "Invalid operation"; + case CL_INVALID_GL_OBJECT: + return "Invalid GL object"; + case CL_INVALID_BUFFER_SIZE: + return "Invalid buffer size"; + case CL_INVALID_MIP_LEVEL: + return "Invalid mip-level"; + case CL_INVALID_GLOBAL_WORK_SIZE: + return "Invalid global work size"; + case CL_INVALID_PROPERTY: + return "Invalid property"; + case CL_INVALID_IMAGE_DESCRIPTOR: + return "Invalid image descriptor"; + case CL_INVALID_COMPILER_OPTIONS: + return "Invalid compiler options"; + case CL_INVALID_LINKER_OPTIONS: + return "Invalid linker options"; + case CL_INVALID_DEVICE_PARTITION_COUNT: + return "Invalid device partition count"; + case CL_INVALID_PIPE_SIZE: + return "Invalid pipe size"; + case CL_INVALID_DEVICE_QUEUE: + return "Invalid device queue"; + case CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR: + return "Invalid GL share group reference KHR"; + default: + return "Unknown OpenCL error code"; + } +} +} // namespace kernel +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h new file mode 100644 index 00000000000..23a3b177a11 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ + +#include +#include +#include "CL/cl2.hpp" +#include "utils/log_adapter.h" + +namespace mindspore::kernel { + +/** + * GetLocalSize + * @param number + * @param max_divider + * @return + */ +template +T GetBiggestDividerWithPriority(T number, N max_divider) { + if (number % 8 == 0 && 8 <= max_divider) { + return (T)8; + } + if (number % 4 == 0 && 4 <= max_divider) { + return (T)4; + } + if (number % 2 == 0 && 2 <= max_divider) { + return (T)2; + } + for (int i = max_divider; i != 0; i--) { + if (number % i == 0) { + return (T)i; + } + } + return (T)1; +} + +/** + * GetLocalSize + * @param n must be non negative + * @param divisor must be greater than zero + * @return + */ +template +T DivideRoundUp(T n, N divisor) { + const T div = static_cast(divisor); + const T q = n / div; + return n % div == 0 ? q : q + 1; +} + +/** + * GetLocalSize + * @param number + * @param n + * @return + */ +template +T AlignByN(T number, N n) { + return DivideRoundUp(number, n) * n; +} + +// GetGlobalSize +std::vector GetGlobalSize(const std::vector &local, const std::vector &global); + +// GetLocalSize +std::vector GetLocalSize(const std::vector &global, int max_size); + +std::string CLErrorCode(cl_int error_code); + + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ + diff --git a/mindspore/lite/src/runtime/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/opencl/CMakeLists.txt new file mode 100644 index 00000000000..5f5e73f8677 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/CMakeLists.txt @@ -0,0 +1,11 @@ +set(OPENCL_RUNTIME_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_allocator.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_kernel.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_runtime.h + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.cc + ${CMAKE_CURRENT_SOURCE_DIR}/opencl_wrapper.h + + ) diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc new file mode 100644 index 00000000000..0e2f595330b --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_allocator.h" +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::lite::opencl { + +OpenCLAllocator::OpenCLAllocator() {} +OpenCLAllocator::~OpenCLAllocator() {} + +void OpenCLAllocator::SetContext(const AllocatorContext &ctx) { + lock_flag_ = ctx.lockFlag; + shift_factor_ = ctx.shiftFactor; +} + +void OpenCLAllocator::Lock() { + if (lock_flag_) { + lock.lock(); + } +} + +void OpenCLAllocator::UnLock() { + if (lock_flag_) { + lock.unlock(); + } +} + +void *OpenCLAllocator::Malloc(size_t size) { + if (size > MAX_MALLOC_SIZE) { + MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; + return nullptr; + } + Lock(); + auto iter = free_list_.lower_bound(size); + if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { + auto mem_buf = iter->second; + free_list_.erase(iter); + allocated_list_[mem_buf->host_ptr_] = mem_buf; + UnLock(); + MS_LOG(DEBUG) << "Malloc buffer from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + return mem_buf->host_ptr_; + } + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + void *host_ptr = nullptr; + void *device_ptr = nullptr; + if (svm_capabilities) { + cl_svm_mem_flags flags = (svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) ? CL_MEM_SVM_FINE_GRAIN_BUFFER : 0; + flags |= (svm_capabilities & CL_DEVICE_SVM_ATOMICS) ? CL_MEM_SVM_ATOMICS : 0; + flags = flags | CL_MEM_READ_WRITE; + host_ptr = clSVMAlloc((*ocl_runtime->Context())(), flags, size, 0); + } else { + cl_int ret = CL_SUCCESS; + cl::Buffer *buffer = + new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")"; + UnLock(); + return nullptr; + } + device_ptr = static_cast(buffer); + host_ptr = ocl_runtime->MapBuffer(*buffer, CL_MAP_READ | CL_MAP_WRITE, size); + ocl_runtime->UnmapBuffer(*buffer, host_ptr); + } + std::unique_ptr mem_buf = std::make_unique(); + mem_buf->size_ = size; + mem_buf->device_ptr_ = device_ptr; + mem_buf->host_ptr_ = host_ptr; + MS_LOG(DEBUG) << "Malloc a new buffer. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ + << ", device addr: " << mem_buf->device_ptr_; + allocated_list_[host_ptr] = mem_buf.release(); + UnLock(); + return host_ptr; +} + +void OpenCLAllocator::Free(void *buf) { + if (buf == nullptr) { + return; + } + Lock(); + auto iter = allocated_list_.find(buf); + if (iter != allocated_list_.end()) { + auto mem_buf = iter->second; + allocated_list_.erase(iter); + free_list_.insert(std::make_pair(mem_buf->size_, mem_buf)); + UnLock(); + return; + } + UnLock(); + free(buf); +} + +size_t OpenCLAllocator::GetTotalSize() { + Lock(); + size_t totalSize = 0; + + for (auto it = allocated_list_.begin(); it != allocated_list_.end(); it++) { + totalSize += it->second->size_; + } + + for (auto it = free_list_.begin(); it != free_list_.end(); it++) { + totalSize += it->second->size_; + } + UnLock(); + return totalSize; +} + +void *OpenCLAllocator::GetDeviceBuffer(void *buffer) { + auto it = allocated_list_.find(buffer); + if (it != allocated_list_.end()) { + return it->second->device_ptr_; + } + return nullptr; +} + +void OpenCLAllocator::Clear() { + Lock(); + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + for (auto it = allocated_list_.begin(); it != allocated_list_.end(); it++) { + if (svm_capabilities) { + clSVMFree((*ocl_runtime->Context())(), it->second->host_ptr_); + MS_LOG(DEBUG) << "OpenCL free svm buffer : " << it->second->host_ptr_; + } else { + cl::Buffer *buff = static_cast(it->second->device_ptr_); + MS_LOG(DEBUG) << "OpenCL free device buffer : " << buff; + delete buff; + } + } + allocated_list_.clear(); + + for (auto it = free_list_.begin(); it != free_list_.end(); it++) { + if (svm_capabilities) { + clSVMFree((*ocl_runtime->Context())(), it->second->host_ptr_); + MS_LOG(DEBUG) << "OpenCL free svm buffer : " << it->second->host_ptr_; + } else { + cl::Buffer *buff = static_cast(it->second->device_ptr_); + MS_LOG(DEBUG) << "OpenCL free device buffer : " << buff; + delete buff; + } + } + free_list_.clear(); + UnLock(); +} + +void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue, bool sync) { + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + if (svm_capabilities) { + if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) { + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + return nullptr; + } + ocl_runtime->MapBuffer(host_ptr, flags, it->second->size_, static_cast(command_queue), sync); + } + return host_ptr; + } + Lock(); + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + return nullptr; + } + MemBuf *mem_buf = it->second; + cl::Buffer *buffer = static_cast(mem_buf->device_ptr_); + void *new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync); + mem_buf->host_ptr_ = new_host_ptr; + allocated_list_.erase(it); + allocated_list_[new_host_ptr] = mem_buf; + UnLock(); + return new_host_ptr; +} + +int OpenCLAllocator::UnmapBuffer(void *host_ptr, void *command_queue) { + auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); + auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); + if (svm_capabilities) { + if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) { + return ocl_runtime->UnmapBuffer(host_ptr); + } + return 0; + } + auto it = allocated_list_.find(host_ptr); + if (it == allocated_list_.end()) { + MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; + return 1; + } + cl::Buffer *buffer = static_cast(it->second->device_ptr_); + return ocl_runtime->UnmapBuffer(*buffer, it->second->host_ptr_, static_cast(command_queue)); +} + +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.h b/mindspore/lite/src/runtime/opencl/opencl_allocator.h new file mode 100644 index 00000000000..e8f6578347f --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ +#define MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "src/runtime/allocator.h" + +namespace mindspore::lite::opencl { + +#define MS_HOST_BUFFER 0 +#define MS_CL_BUFFER (1 << 1) +#define MS_CL_IMAGE2D (1 << 2) +typedef int32_t OpenCLMemoryType; + +struct OpenclMemory { + void *host_ptr{nullptr}; + void *device_ptr{nullptr}; + OpenCLMemoryType mem_type{MS_HOST_BUFFER | MS_CL_BUFFER}; +}; + +class OpenCLAllocator : public Allocator { + public: + OpenCLAllocator(); + ~OpenCLAllocator() override; + void SetContext(const AllocatorContext &ctx) override; + void *Malloc(size_t size) override; + void Free(void *ptr) override; + size_t GetTotalSize() override; + void Clear() override; + void *GetDeviceBuffer(void *buffer); + void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true); + int UnmapBuffer(void *host_ptr, void *command_queue = nullptr); + + private: + void Lock(); + void UnLock(); + struct MemBuf { + size_t size_; + void *device_ptr_; + void *host_ptr_; + }; + + std::mutex lock; + // buf, membuf> + std::unordered_map allocated_list_; + std::multimap free_list_; + // 6 is empirical value + int shift_factor_ = 6; + bool lock_flag_ = false; +}; + +} // namespace mindspore::lite::opencl + +#endif // MINDSPORE_LITE_SRC_OPENCL_ALLOCATOR_H_ + diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc new file mode 100644 index 00000000000..bd6d58b10c0 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_executor.h" +#include "src/runtime/kernel/arm/opclib/pack.h" +#include "include/errorcode.h" + +namespace mindspore::lite::opencl { +int OpenCLExecutor::Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator, + const kernel::KernelCallBack &before, const kernel::KernelCallBack &after) { + MS_ASSERT(nullptr != allocator); + for (auto &inTensor : inputs) { + if (inTensor == nullptr) { + MS_LOG(ERROR) << "Graph input tensor is nullptr"; + return RET_ERROR; + } + if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4) { + if (inTensor->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "Model input should be NHWC, actual is " << schema::EnumNameFormat(inTensor->GetFormat()); + return RET_ERROR; + } else { + TransformTensorLayout(inTensor, schema::Format_NHWC4); + // TransformTensorLayout(inTensor, schema::Format_NC4HW4); + } + } + } + kernel::LiteKernelUtil::InitTensorRefCount(kernels); + for (auto *kernel : kernels) { + MS_ASSERT(nullptr != kernel); + auto &outputs = kernel->GetOutputs(); + for (auto *output : outputs) { + MS_ASSERT(nullptr != output); + output->MallocData(allocator_); + } + kernel::CallBackParam callbackParam; + callbackParam.name_callback_aram = kernel->Name(); + + if (before != nullptr) { + if (!before(kernel->GetInputs(), kernel->GetOutputs(), callbackParam)) { + MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name(); + } + } + auto ret = kernel->Run(); + if (0 != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << kernel->Name(); + return ret; + } + + if (after != nullptr) { + if (!after(kernel->GetInputs(), kernel->GetOutputs(), callbackParam)) { + MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name(); + } + } + for (auto input_kernel : kernel->GetInKernels()) { + MS_EXCEPTION_IF_NULL(input_kernel); + ret = input_kernel->DecOutTensorRefCount(allocator_); + if (0 != ret) { + MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed"; + } + } + } + // output format transform + for (auto &outTensor : outputs) { + if (outTensor == nullptr) { + MS_LOG(ERROR) << "Graph output tensor is nullptr"; + return RET_ERROR; + } + if (outTensor->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "Model output tensor should be NHWC"; + } + } + return RET_OK; +} + +int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(4 == tensor->shape().size()); + auto data_type = tensor->data_type(); + switch (data_type) { + case kNumberTypeInt8: + return TransformTensorLayoutUint8(tensor, dst_format); + case kNumberTypeFloat32: + return TransformTensorLayoutFp32(tensor, dst_format); + default: + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format); + return RET_ERROR; + } + return RET_OK; +} + +int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(nullptr != allocator_); + MS_ASSERT(4 == tensor->shape().size()); + if (dst_format == schema::Format_NHWC4) { + auto *src_data = tensor->Data(); + auto *dst_data = allocator_->Malloc(tensor->Size()); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + dst_data = reinterpret_cast(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); + PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); + tensor->SetData(dst_data); + tensor->SetFormat(dst_format); + allocator_->Free(src_data); + return RET_OK; + } else { + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in float32"; + return RET_ERROR; + } +} + +int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format) { + MS_ASSERT(nullptr != tensor); + MS_ASSERT(4 == tensor->shape().size()); + // auto src_format = tensor->GetFormat(); + // todo + MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + << schema::EnumNameFormat(dst_format) << " in uint8"; + return RET_ERROR; +} +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.h b/mindspore/lite/src/runtime/opencl/opencl_executor.h new file mode 100644 index 00000000000..0b0cf4d43ad --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_OPENCL_EXECUTOR_H_ + +#include +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/allocator.h" +#include "src/lite_kernel.h" +#include "src/executor.h" + +namespace mindspore::lite::opencl { +class OpenCLExecutor : Executor { + public: + OpenCLExecutor() : Executor() { + allocator_ = OpenCLRuntime::GetInstance()->GetAllocator(); + } + + int Prepare(const std::vector &kernels) { return 0; } + + int Run(std::vector &inputs, std::vector &outputs, + std::vector &kernels, Allocator *allocator = nullptr, + const kernel::KernelCallBack &before = nullptr, const kernel::KernelCallBack &after = nullptr); + + protected: + int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format); + + int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format); + + int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format); + + protected: + Context *context = nullptr; + OpenCLAllocator *allocator_; +}; + +} // namespace mindspore::lite::opencl +#endif + diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc new file mode 100644 index 00000000000..c4993a25334 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc @@ -0,0 +1,609 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/opencl/opencl_runtime.h" +#include +#include +#ifdef SHARING_MEM_WITH_OPENGL +#include +#endif +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/opencl/opencl_allocator.h" +#ifdef PROGRAM_WITH_IL +#include "src/backend/opencl/cl/program.inc" +#endif + +#ifndef ROUND_UP +#define ROUND_UP(x, y) ((static_cast(x) + static_cast(y) - (1)) / static_cast(y) * static_cast(y)) +#endif + +using mindspore::kernel::CLErrorCode; + +namespace mindspore::lite::opencl { + +std::map g_opencl_program_map; + +static std::mutex g_mtx; +static std::mutex g_init_mtx; + +// magic number +static std::map AdrenoSubGroup{ + {640, 128}, {630, 128}, {616, 128}, {612, 64}, {610, 64}, {540, 32}, {530, 32}, + {512, 32}, {510, 32}, {509, 32}, {506, 32}, {505, 32}, {405, 32}, {330, 16}, +}; + +#ifdef USE_OPENCL_WRAPPER +std::shared_ptr OpenCLWrapper::opencl_wrapper_singleton_ = nullptr; +#endif +std::shared_ptr OpenCLRuntime::opencl_runtime_singleton_ = nullptr; +bool OpenCLRuntime::init_done_ = false; + +OpenCLRuntime *OpenCLRuntime::GetInstance() { + std::unique_lock lck(g_mtx); + if (opencl_runtime_singleton_.get() == nullptr) { + opencl_runtime_singleton_.reset(new OpenCLRuntime()); + opencl_runtime_singleton_->Init(); + } + return opencl_runtime_singleton_.get(); +} + +void OpenCLRuntime::DeleteInstance() { + std::unique_lock lck(g_mtx); + init_done_ = false; + if (opencl_runtime_singleton_ != nullptr) { + opencl_runtime_singleton_.reset(); + opencl_runtime_singleton_ = nullptr; + } +} + +OpenCLRuntime::OpenCLRuntime() { default_build_opts_ = " -cl-mad-enable -cl-fast-relaxed-math -Werror"; } + +// Init will get platforms info, get devices info, create opencl context. +int OpenCLRuntime::Init() { + std::unique_lock lck(g_init_mtx); + + if (init_done_) { + return 0; + } + MS_LOG(INFO) << "OpenCL version: CL_TARGET_OPENCL_VERSION " << CL_TARGET_OPENCL_VERSION; + MS_LOG(INFO) << "CL_HPP_TARGET_OPENCL_VERSION " << CL_HPP_TARGET_OPENCL_VERSION; + MS_LOG(INFO) << "CL_HPP_MINIMUM_OPENCL_VERSION " << CL_HPP_MINIMUM_OPENCL_VERSION; + +#ifdef USE_OPENCL_WRAPPER + if (false == OpenCLWrapper::GetInstance()->LoadOpenCLLibrary()) { + MS_LOG(ERROR) << "Load OpenCL symbols failed!"; + return 1; + } +#endif // USE_OPENCL_WRAPPER + + std::vector platforms; + cl::Platform::get(&platforms); + if (platforms.size() == 0) { + MS_LOG(ERROR) << "OpenCL Platform not found!"; + return 1; + } + + // search GPU + std::vector devices; + for (auto it = platforms.begin(); it != platforms.end(); ++it) { + std::string platform_name; + it->getInfo(CL_PLATFORM_NAME, &platform_name); + it->getDevices(CL_DEVICE_TYPE_GPU, &devices); + MS_LOG(INFO) << "Platform (" << platform_name << ") has " << devices.size() << " GPUs"; + + if (devices.size() > 0) { + std::string device_name = devices[0].getInfo(); + MS_LOG(INFO) << "Find GPU: " << device_name.c_str(); + cl::Platform::setDefault(*it); + break; + } + } + + // not found, return error code. + if (devices.size() == 0) { + MS_LOG(ERROR) << "OpenCL Device not found!"; + return 1; + } + + device_ = std::make_shared(); + *device_ = devices[0]; + max_work_item_sizes_ = device_->getInfo(); + const std::string device_name = device_->getInfo(); + const std::string device_version = device_->getInfo(); + const std::string opencl_version = device_->getInfo(); + MS_LOG(INFO) << "Device name:\t" << device_name; + MS_LOG(INFO) << "Opencl version:\t" << device_version; + MS_LOG(INFO) << "Highest OpenCL c version:\t" << opencl_version; + MS_LOG(INFO) << "Max work item size:\t" + << max_work_item_sizes_[0] << " : " + << max_work_item_sizes_[1] << " : " + << max_work_item_sizes_[2]; + + gpu_info_ = ParseGpuInfo(device_name, device_version); + + cl_int err; +#if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120) + // create context from glcontext + MS_LOG(INFO) << "Create special opencl context to share with OpenGL"; + cl_context_properties context_prop[] = {CL_GL_CONTEXT_KHR, (cl_context_properties)eglGetCurrentContext(), + CL_EGL_DISPLAY_KHR, (cl_context_properties)eglGetCurrentDisplay(), 0}; + context_ = std::make_shared(std::vector{*device_}, context_prop, nullptr, nullptr, &err); + + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Create special OpenCL context falied, Create common OpenCL context then."; + context_ = std::make_shared(std::vector{*device_}, nullptr, nullptr, nullptr, &err); + } +#else + MS_LOG(INFO) << "Create common opencl context"; + context_ = std::make_shared(std::vector{*device_}, nullptr, nullptr, nullptr, &err); +#endif + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Context create failed: " << CLErrorCode(err); + return 1; + } + + // get cache size, compute units and frequency. + device_->getInfo(CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, &global_memery_cachesize_); + device_->getInfo(CL_DEVICE_MAX_COMPUTE_UNITS, &compute_units_); + device_->getInfo(CL_DEVICE_MAX_CLOCK_FREQUENCY, &max_freq_); + cl_device_fp_config fp_config; + auto success = device_->getInfo(CL_DEVICE_HALF_FP_CONFIG, &fp_config); + support_fp16_ = CL_SUCCESS == success && fp_config > 0; + + err = device_->getInfo(CL_DEVICE_SVM_CAPABILITIES, &svm_capabilities_); + if (err != CL_SUCCESS || svm_capabilities_ == 0) { + svm_capabilities_ = 0; + MS_LOG(INFO) << "SVM capalibilties: " + << "NONE"; + } else { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_FINE_GRAIN_BUFFER"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_COARSE_GRAIN_BUFFER"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_COARSE_GRAIN_SYSTEM"; + } + if (svm_capabilities_ & CL_DEVICE_SVM_ATOMICS) { + MS_LOG(INFO) << "SVM capalibilties: " + << "SVM_ATOMICS"; + } + } + + MS_LOG(INFO) << "Global Mem Cache Size: " << global_memery_cachesize_; + MS_LOG(INFO) << "Compute Unit: " << compute_units_; + MS_LOG(INFO) << "Clock Frequency: " << max_freq_ << " MHz"; + + cl_command_queue_properties properties = 0; +#if MS_OPENCL_PROFILE + properties |= CL_QUEUE_PROFILING_ENABLE; +#endif + + default_command_queue_ = std::make_shared(*context_, *device_, properties, &err); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Command Queue create failed: " << CLErrorCode(err); + return 1; + } + + allocator_ = std::make_shared(); +#ifdef PROGRAM_WITH_IL + std::string flag = ""; + CreateProgramFromIL(g_program_binary, flag); +#endif + init_done_ = true; + MS_LOG(INFO) << "OpenCLRuntime init done!"; + + return 0; +} + +OpenCLRuntime::~OpenCLRuntime() { + program_map_.clear(); + // allocator_->Clear(); + allocator_.reset(); + default_command_queue_.reset(); + context_.reset(); + device_.reset(); +} + +cl::Context *OpenCLRuntime::Context() { return context_.get(); } + +cl::Device *OpenCLRuntime::Device() { return device_.get(); } + +uint64_t OpenCLRuntime::DeviceGlobalMemoryCacheSize() const { return global_memery_cachesize_; } + +int OpenCLRuntime::DeviceMaxWorkGroupSize() const { return max_work_group_size; } + +uint32_t OpenCLRuntime::DeviceComputeUnits() const { return compute_units_; } + +uint32_t OpenCLRuntime::DeviceMaxFreq() const { return max_freq_; } + +// get kernel enqueue max work group size +uint64_t OpenCLRuntime::GetMaxWorkGroupSize(const cl::Kernel &kernel) { + uint64_t max_workgroup_size = 0; + int ret = kernel.getWorkGroupInfo(*device_, CL_KERNEL_WORK_GROUP_SIZE, &max_workgroup_size); + if (ret != 0) max_workgroup_size = 0; + return max_workgroup_size; +} + +// opencl 2.0 can get SubGroupSize. +uint32_t OpenCLRuntime::GetSubGroupSize(const cl::Kernel &kernel, const cl::NDRange &range) { + uint32_t sub_group_size = 0; + + if (ADRENO == gpu_info_.type) { +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 && CL_TARGET_OPENCL_VERSION >= 210 && defined(CL_HPP_USE_CL_SUB_GROUPS_KHR) + cl_int cl_ret; + sub_group_size = kernel.getSubGroupInfo(*device_, range, &cl_ret); + if (cl_ret != CL_SUCCESS) { + CHECK_CL_SUCCESS(cl_ret) + sub_group_size = 0; + } +#else + if (AdrenoSubGroup.find(gpu_info_.model_num) != AdrenoSubGroup.end()) { + sub_group_size = AdrenoSubGroup[gpu_info_.model_num]; + } +#endif + } + + return sub_group_size; +} + +GpuInfo OpenCLRuntime::GetGpuInfo() { return gpu_info_; } + +bool OpenCLRuntime::GetFp16Enable() const { return fp16_enable_; } + +// if support fp16, set fp16 will success. +bool OpenCLRuntime::SetFp16Enable(bool enable) { + fp16_enable_ = enable && support_fp16_; + return fp16_enable_ == enable; +} + +int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name, + const std::set &build_options) { + std::string build_options_str; + // set default macro + if (fp16_enable_) { + // fp16 enable, kernel will use half and read_imageh and write_imageh. + build_options_str = + "-DFLOAT=half -DFLOAT4=half4 -DRI_F=read_imageh " + "-DWI_F=write_imageh"; + } else { + // fp16 not enable, kernel will use float and read_imagef and write_imagef. + build_options_str = + "-DFLOAT=float -DFLOAT4=float4 -DRI_F=read_imagef " + "-DWI_F=write_imagef"; + } + + build_options_str = std::accumulate( + build_options.begin(), build_options.end(), build_options_str, + [](const std::string &options, const std::string &option) -> std::string { return options + " " + option; }); + build_options_str += default_build_opts_; + // program identifier = program_name + build_options + std::string build_program_key = program_name + build_options_str; + + auto build_program_it = program_map_.find(build_program_key); + cl::Program program; + // if search program identifier exist, then use it. + if (build_program_it != program_map_.end()) { + program = build_program_it->second; + } else { + // load program and build program + auto status = this->LoadProgram(program_name, &program); + if (!status) { + MS_LOG(ERROR) << "load program (" << program_name << ") failed!"; + return 1; + } + status = this->BuildProgram(build_options_str, &program); + if (!status) { + MS_LOG(ERROR) << program_name << " build failed!"; + return 1; + } + program_map_.emplace(build_program_key, program); + } + + cl_int err; + kernel = cl::Kernel(program, kernel_name.c_str(), &err); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << kernel_name << " Kernel create failed:" << CLErrorCode(err); + return 1; + } + return 0; +} + +// Run Kernel with 1D, 2D, 3D group size, and local size can be empty. +int OpenCLRuntime::RunKernel(const cl_kernel &kernel, const std::vector &global, + const std::vector &local, cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + MS_ASSERT(local.size() == 0 || local.size() == global.size()); + std::vector internal_global_ws = global; + for (size_t i = 0; i < local.size(); ++i) { + internal_global_ws[i] = ROUND_UP(global[i], local[i]); + } + + MS_LOG(INFO) << "global size: " << global.size() << ", local size: " << local.size(); + for (size_t i = 0; i < global.size(); i++) { + MS_LOG(DEBUG) << "global[" << i << "] = " << global[i]; + } + for (size_t i = 0; i < local.size(); i++) { + MS_LOG(DEBUG) << "local[" << i << "] = " << local[i]; + } + + cl::Event event; + cl_int error = CL_SUCCESS; + if (local.size() == 0) { + error = + clEnqueueNDRangeKernel((*command_queue)(), kernel, global.size(), 0, global.data(), nullptr, 0, nullptr, nullptr); + } else { + error = clEnqueueNDRangeKernel((*command_queue)(), kernel, global.size(), 0, global.data(), local.data(), 0, + nullptr, nullptr); + } + + if (error != CL_SUCCESS) { + MS_LOG(ERROR) << "Kernel execute failed:" << CLErrorCode(error); + return 1; + } + MS_LOG(INFO) << "RunKernel success!"; + return 0; +} + +// Run Kernel with 1D, 2D, 3D group size, and local size can be empty. +int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector &global, + const std::vector &local, cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + MS_ASSERT(local.size() == 0 || local.size() == global.size()); + std::vector internal_global_ws = global; + for (size_t i = 0; i < local.size(); ++i) { + internal_global_ws[i] = ROUND_UP(global[i], local[i]); + } + + MS_LOG(INFO) << "global size: " << global.size() << ", local size: " << local.size(); + for (size_t i = 0; i < global.size(); i++) { + MS_LOG(DEBUG) << "global[" << i << "] = " << global[i]; + } + for (size_t i = 0; i < local.size(); i++) { + MS_LOG(DEBUG) << "local[" << i << "] = " << local[i]; + } + + cl::Event event; + cl_int err = CL_SUCCESS; + + cl::NDRange global_range = cl::NullRange; + cl::NDRange local_range = cl::NullRange; + if (global.size() == 1) { + global_range = cl::NDRange(internal_global_ws[0]); + if (!local.empty()) { + local_range = cl::NDRange(local[0]); + } + } else if (global.size() == 2) { + global_range = cl::NDRange(internal_global_ws[0], internal_global_ws[1]); + if (!local.empty()) { + local_range = cl::NDRange(local[0], local[1]); + } + } else if (global.size() == 3) { + global_range = cl::NDRange(internal_global_ws[0], internal_global_ws[1], internal_global_ws[2]); + if (!local.empty()) { + local_range = cl::NDRange(local[0], local[1], local[2]); + } + } else { + MS_LOG(INFO) << "Not supported NDRange!"; + return 1; + } + + err = command_queue->enqueueNDRangeKernel(kernel, cl::NullRange, global_range, local_range, nullptr, &event); + + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Kernel execute failed:" << CLErrorCode(err); + return 1; + } + MS_LOG(INFO) << "RunKernel success!"; +#if MS_OPENCL_PROFILE + event.wait(); + cl_ulong time_start; + cl_ulong time_end; + event.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start); + event.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end); + double nanoSeconds = time_end - time_start; + MS_LOG(INFO) << "OpenCl Execution time is: " << nanoSeconds / 1000000.0 << "ms"; +#endif + return 0; +} + +// get gpu divce type +GpuInfo OpenCLRuntime::ParseGpuInfo(std::string device_name, std::string device_version) { + GpuInfo info; + + if (device_name == "QUALCOMM Adreno(TM)") { + info.type = ADRENO; + sscanf(device_version.c_str(), "%*s%f%*s%d", &info.opencl_version, &info.model_num); + + } else if (device_name.find("Mali") != std::string::npos) { + info.type = MALI; + + // Mali type MALI-G or MALI_T + if (device_name.find("Mali-G") != std::string::npos) { + info.type = MALI_G; + sscanf(device_name.c_str(), "Mali-G%d", &info.model_num); + } else if (device_name.find("Mali-T") != std::string::npos) { + info.type = MALI_T; + sscanf(device_name.c_str(), "Mali-T%d", &info.model_num); + } + sscanf(device_version.c_str(), "%*s%f%*s", &info.opencl_version); + } + + return info; +} + +bool OpenCLRuntime::LoadSource(const std::string &program_name, const std::string &source) { + auto it_source = g_opencl_program_map.find(program_name); + if (it_source != g_opencl_program_map.end()) { + it_source->second = source; + } else { + g_opencl_program_map.emplace(program_name, source); + } + return true; +} + +// load program with program name. +bool OpenCLRuntime::LoadProgram(const std::string &program_name, cl::Program *program) { + auto it_source = g_opencl_program_map.find(program_name); + if (it_source != g_opencl_program_map.end()) { + cl::Program::Sources sources; + sources.push_back(it_source->second); + *program = cl::Program(*context_, sources); + return true; + } else { + MS_LOG(ERROR) << "Can't find kernel source !"; + return false; + } +} + +// build program with build options +bool OpenCLRuntime::BuildProgram(const std::string &build_options, cl::Program *program) { + cl_int ret = program->build({*device_}, build_options.c_str()); + if (ret != CL_SUCCESS) { + if (program->getBuildInfo(*device_) == CL_BUILD_ERROR) { + std::string build_log = program->getBuildInfo(*device_); + MS_LOG(ERROR) << "Program build log: " << build_log; + } + MS_LOG(ERROR) << "Build program failed: " << CLErrorCode(ret); + return false; + } + return true; +} + +bool OpenCLRuntime::CopyDeviceMemToHost(void *dst, const void *src, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int cl_ret = CL_SUCCESS; + const cl::Buffer *buffer = static_cast(src); + if (command_queue != nullptr) { + cl_ret = command_queue->enqueueReadBuffer(*buffer, sync, 0, size, dst); + } + return cl_ret == CL_SUCCESS; +} + +bool OpenCLRuntime::CopyHostMemToDevice(const void *dst, const void *src, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int cl_ret = CL_SUCCESS; + const cl::Buffer *buffer = static_cast(dst); + if (command_queue != nullptr) { + cl_ret = command_queue->enqueueWriteBuffer(*buffer, sync, 0, size, src); + } + return cl_ret == CL_SUCCESS; +} + +void *OpenCLRuntime::MapBuffer(const cl::Buffer buffer, int flags, size_t size, cl::CommandQueue *command_queue, + bool sync) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueMapBuffer(buffer, sync, flags, 0, size); +} + +int OpenCLRuntime::MapBuffer(void *host_ptr, int flags, size_t size, cl::CommandQueue *command_queue, bool sync) const { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + return 0; + } + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueMapSVM(host_ptr, sync, flags, size); +} + +int OpenCLRuntime::UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue) const { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueUnmapMemObject(buffer, host_ptr); +} + +int OpenCLRuntime::UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue) const { + if (svm_capabilities_ & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) { + return 0; + } + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + return command_queue->enqueueUnmapSVM(host_ptr); +} + +bool OpenCLRuntime::SyncCommandQueue(cl::CommandQueue *command_queue) { + if (command_queue == nullptr) { + command_queue = default_command_queue_.get(); + } + cl_int ret = command_queue->finish(); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Command queue sync failed: " << CLErrorCode(ret); + return 1; + } + return ret == CL_SUCCESS; +} + +int OpenCLRuntime::GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id) { + size_t max_work_group_size; + cl_int err = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE, sizeof(size_t), + &max_work_group_size, nullptr); + if (err != CL_SUCCESS) { + MS_LOG(ERROR) << "Failed to get info CL_KERNEL_WORK_GROUP_SIZE " << CLErrorCode(err); + } + return static_cast(max_work_group_size); +} + +bool OpenCLRuntime::CreateKernelFromIL(cl_kernel &kernel, const std::string kernel_name) { + cl_int ret = CL_SUCCESS; + kernel = clCreateKernel(il_program_, kernel_name.c_str(), &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create kernel with IL failed: " << CLErrorCode(ret); + } + return ret == CL_SUCCESS; +} + +// build program with IL +bool OpenCLRuntime::CreateProgramFromIL(const std::vector program_binary, const std::string flag) { +#if CL_HPP_TARGET_OPENCL_VERSION >= 210 + size_t program_length = program_binary.size(); + cl_int ret = CL_SUCCESS; + il_program_ = clCreateProgramWithIL((*context_)(), program_binary.data(), program_length, &ret); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Create program with IL failed: " << CLErrorCode(ret); + return false; + } + + ret = clBuildProgram(il_program_, 1, &(*device_)(), flag.c_str(), NULL, NULL); + if (ret != CL_SUCCESS) { + MS_LOG(ERROR) << "Build program with IL failed: " << CLErrorCode(ret); + } + return ret == CL_SUCCESS; +#else + MS_LOG(ERROR) << "Create program with IL failed! The compute capabitity of device should be 2.1 and higher."; + return false; +#endif +} + +} // namespace mindspore::lite::opencl + diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.h b/mindspore/lite/src/runtime/opencl/opencl_runtime.h new file mode 100644 index 00000000000..64593f553bd --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.h @@ -0,0 +1,155 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); +j* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ +#define MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_wrapper.h" +#include "src/runtime/opencl/opencl_allocator.h" + +namespace mindspore::lite::opencl { + +enum GpuType { OTHER = 0, ADRENO = 1, MALI = 2, MALI_T = 3, MALI_G = 4 }; + +struct GpuInfo { + GpuType type = OTHER; + int model_num = 0; + float opencl_version = 0; +}; + +// Base GPU cache size used for computing local work group size. +const int32_t g_base_gpu_mem_cachesize = 16384; + +class OpenCLRuntime { + public: + static OpenCLRuntime *GetInstance(); + static void DeleteInstance(); + + ~OpenCLRuntime(); + OpenCLRuntime(const OpenCLRuntime &) = delete; + OpenCLRuntime &operator=(const OpenCLRuntime &) = delete; + + int Init(); + + cl::Context *Context(); + cl::Device *Device(); + OpenCLAllocator *GetAllocator() { return allocator_.get(); } + cl::CommandQueue *GetDefaultCommandQueue() { return default_command_queue_.get(); } + uint64_t DeviceGlobalMemoryCacheSize() const; + int DeviceMaxWorkGroupSize() const; + uint32_t DeviceComputeUnits() const; + uint32_t DeviceMaxFreq() const; + uint64_t GetMaxWorkGroupSize(const cl::Kernel &kernel); + uint32_t GetSubGroupSize(const cl::Kernel &kernel, const cl::NDRange &range = cl::NullRange); + GpuInfo GetGpuInfo(); + bool GetFp16Enable() const; + bool SetFp16Enable(bool enable); + const std::vector &GetWorkItemSize() { return max_work_item_sizes_; } + cl_device_svm_capabilities GetSVMCapabilities() const { return svm_capabilities_; } + + template + typename std::enable_if::value, cl_int>::type SetKernelArg(cl_kernel &kernel, uint32_t index, + const T value) { + if (svm_capabilities_) { + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value; + return clSetKernelArgSVMPointer(kernel, index, value); + } else { + cl::Buffer *buffer = reinterpret_cast(allocator_->GetDeviceBuffer(value)); + MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value; + return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); + } + } + + template + typename std::enable_if::value, cl_int>::type SetKernelArg(cl_kernel &kernel, uint32_t index, + const T value) { + return clSetKernelArg(kernel, index, sizeof(value), &value); + } + + template + int SetKernelArg(cl::Kernel &kernel, uint32_t index, const T &value) { + return SetKernelArg(kernel(), index, value); + } + + bool CreateProgramFromIL(const std::vector program_binary, const std::string flag); + bool CreateKernelFromIL(cl_kernel &kernel, const std::string kernel_name); + bool LoadSource(const std::string &program_name, const std::string &source); + int BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name, + const std::set &build_options); + int RunKernel(const cl_kernel &kernel, const std::vector &global, const std::vector &local, + cl::CommandQueue *command_queue); + int RunKernel(const cl::Kernel &kernel, const std::vector &global, const std::vector &local, + cl::CommandQueue *command_queue); + bool CopyDeviceMemToHost(void *dst, const void *src, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + bool CopyHostMemToDevice(const void *dst, const void *src, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + void *MapBuffer(const cl::Buffer buffer, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + int MapBuffer(void *host_ptr, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, + bool sync = false) const; + int UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; + int UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; + bool SyncCommandQueue(cl::CommandQueue *command_queue = nullptr); + + /** + * Get kernel max worker group size. + * @param kernel + * @param device_id + * @return max_work_group_size + */ + int GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id); + + private: + OpenCLRuntime(); + GpuInfo ParseGpuInfo(std::string device_name, std::string device_version); + + bool LoadProgram(const std::string &program_name, cl::Program *program); + bool BuildProgram(const std::string &build_options, cl::Program *program); + + private: + static std::shared_ptr opencl_runtime_singleton_; + static bool init_done_; + std::shared_ptr default_command_queue_{nullptr}; + std::shared_ptr context_{nullptr}; + std::shared_ptr device_{nullptr}; + std::shared_ptr allocator_{nullptr}; + std::map program_map_{}; + cl_program il_program_{0}; + uint64_t global_memery_cachesize_{0}; + int max_work_group_size; + uint32_t compute_units_{0}; + uint32_t max_freq_{0}; + std::string default_build_opts_{""}; + GpuInfo gpu_info_; + bool support_fp16_{false}; + bool fp16_enable_{false}; + cl_device_svm_capabilities svm_capabilities_{0}; + std::vector max_work_item_sizes_; +}; + +} // namespace mindspore::lite::opencl + +#endif // MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ + diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc new file mode 100644 index 00000000000..084afc344a0 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc @@ -0,0 +1,683 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef USE_OPENCL_WRAPPER + +#include "src/runtime/opencl/opencl_wrapper.h" +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "src/runtime/opencl/opencl_runtime.h" + +namespace mindspore::lite::opencl { + +// default opencl library path +static const std::vector g_opencl_library_paths = { +#if defined(__APPLE__) || defined(__MACOSX) + "libOpenCL.so", "/System/Library/Frameworks/OpenCL.framework/OpenCL" +#elif defined(__ANDROID__) +#if defined(__aarch64__) + // Mali + "/system/vendor/lib64/egl/libGLES_mali.so", + "/system/lib64/egl/libGLES_mali.so", + // Qualcomm Adreno + "/system/vendor/lib64/libOpenCL.so", + "/system/lib64/libOpenCL.so", +#else + // Qualcomm Adreno + "/system/vendor/lib/libOpenCL.so", "/system/lib/libOpenCL.so", + // Mali + "/system/vendor/lib/egl/libGLES_mali.so", "/system/lib/egl/libGLES_mali.so", + // other + "/system/vendor/lib/libPVROCL.so", "/data/data/org.pocl.libs/files/lib/libpocl.so" +#endif + "libOpenCL.so", + "libGLES_mali.so", + "libmali.so", +#elif defined(__linux__) + "/usr/lib/libOpenCL.so", + "/usr/local/lib/libOpenCL.so", + "/usr/local/lib/libpocl.so", + "/usr/lib64/libOpenCL.so", + "/usr/lib32/libOpenCL.so", + "libOpenCL.so", + // intel + "/opt/intel/system_studio_2020/opencl/SDK/lib64/libOpenCL.so", +#endif +}; + +OpenCLWrapper *OpenCLWrapper::GetInstance() { + static std::once_flag opencl_wrapper_once; + std::call_once(opencl_wrapper_once, + []() { opencl_wrapper_singleton_ = std::shared_ptr(new OpenCLWrapper()); }); + + return opencl_wrapper_singleton_.get(); +} + +OpenCLWrapper::OpenCLWrapper() {} + +OpenCLWrapper::~OpenCLWrapper() { + if (nullptr == opencl_wrapper_singleton_.get()) return; + opencl_wrapper_singleton_->UnLoadOpenCLLibrary(); +} + +// load default library path +bool OpenCLWrapper::LoadOpenCLLibrary() { + if (handle_ != nullptr) { + return true; + } + for (const auto &lib_path : g_opencl_library_paths) { + if (LoadLibraryFromPath(lib_path)) { + MS_LOG(DEBUG) << "Find a OpenCL dynamic library : " << lib_path; + return true; + } + } + return false; +} + +bool OpenCLWrapper::UnLoadOpenCLLibrary() { + if (handle_ != nullptr) { + if (dlclose(handle_) != 0) { + return false; + } + handle_ = nullptr; + return true; + } + return true; +} + +bool OpenCLWrapper::LoadLibraryFromPath(const std::string &library_path) { + handle_ = dlopen(library_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle_ == nullptr) { + return false; + } + +// load function ptr use dlopen and dlsym. +#define LOAD_OPENCL_FUNCTION_PTR(func_name) \ + func_name = reinterpret_cast(dlsym(handle_, #func_name)); \ + if (func_name == nullptr) { \ + MS_LOG(ERROR) << "load func (" << #func_name << ") from (" << library_path << ") failed!"; \ + return false; \ + } + + LOAD_OPENCL_FUNCTION_PTR(clGetPlatformIDs); + LOAD_OPENCL_FUNCTION_PTR(clGetPlatformInfo); + LOAD_OPENCL_FUNCTION_PTR(clBuildProgram); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueNDRangeKernel); + LOAD_OPENCL_FUNCTION_PTR(clSetKernelArg); + LOAD_OPENCL_FUNCTION_PTR(clReleaseKernel); + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithSource); + LOAD_OPENCL_FUNCTION_PTR(clCreateBuffer); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage2D); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage3D); + LOAD_OPENCL_FUNCTION_PTR(clRetainKernel); + LOAD_OPENCL_FUNCTION_PTR(clCreateKernel); + LOAD_OPENCL_FUNCTION_PTR(clGetProgramInfo); + LOAD_OPENCL_FUNCTION_PTR(clFlush); + LOAD_OPENCL_FUNCTION_PTR(clFinish); + LOAD_OPENCL_FUNCTION_PTR(clReleaseProgram); + LOAD_OPENCL_FUNCTION_PTR(clRetainContext); + LOAD_OPENCL_FUNCTION_PTR(clGetContextInfo); + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithBinary); + LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clGetCommandQueueInfo); + LOAD_OPENCL_FUNCTION_PTR(clReleaseCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueMapBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueMapImage); + LOAD_OPENCL_FUNCTION_PTR(clRetainProgram); + LOAD_OPENCL_FUNCTION_PTR(clGetProgramBuildInfo); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueReadBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueWriteBuffer); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueReadImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueWriteImage); + LOAD_OPENCL_FUNCTION_PTR(clWaitForEvents); + LOAD_OPENCL_FUNCTION_PTR(clReleaseEvent); + LOAD_OPENCL_FUNCTION_PTR(clCreateContext); + LOAD_OPENCL_FUNCTION_PTR(clCreateContextFromType); + LOAD_OPENCL_FUNCTION_PTR(clReleaseContext); + LOAD_OPENCL_FUNCTION_PTR(clRetainCommandQueue); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueUnmapMemObject); + LOAD_OPENCL_FUNCTION_PTR(clRetainMemObject); + LOAD_OPENCL_FUNCTION_PTR(clReleaseMemObject); + LOAD_OPENCL_FUNCTION_PTR(clGetDeviceInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetDeviceIDs); + LOAD_OPENCL_FUNCTION_PTR(clRetainEvent); + LOAD_OPENCL_FUNCTION_PTR(clGetKernelWorkGroupInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetEventInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetEventProfilingInfo); + LOAD_OPENCL_FUNCTION_PTR(clGetImageInfo); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyBufferToImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueCopyImageToBuffer); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + LOAD_OPENCL_FUNCTION_PTR(clRetainDevice); + LOAD_OPENCL_FUNCTION_PTR(clReleaseDevice); + LOAD_OPENCL_FUNCTION_PTR(clCreateImage); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + // LOAD_OPENCL_FUNCTION_PTR(clGetKernelSubGroupInfoKHR); + LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueueWithProperties); + LOAD_OPENCL_FUNCTION_PTR(clGetExtensionFunctionAddress); + LOAD_OPENCL_FUNCTION_PTR(clSVMAlloc); + LOAD_OPENCL_FUNCTION_PTR(clSVMFree); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueSVMMap); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueSVMUnmap); + LOAD_OPENCL_FUNCTION_PTR(clSetKernelArgSVMPointer); +#ifdef PROGRAM_WITH_IL + LOAD_OPENCL_FUNCTION_PTR(clCreateProgramWithIL); +#endif +#endif + +#undef LOAD_OPENCL_FUNCTION_PTR + + return true; +} + +} // namespace mindspore::lite::opencl + +// clGetPlatformIDs wrapper, use OpenCLWrapper function. use OpenCLWrapper function. +cl_int clGetPlatformIDs(cl_uint num_entries, cl_platform_id *platforms, cl_uint *num_platforms) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetPlatformIDs; + MS_ASSERT(func != nullptr); + return func(num_entries, platforms, num_platforms); +} + +// clGetPlatformInfo wrapper, use OpenCLWrapper function. use OpenCLWrapper function. +cl_int clGetPlatformInfo(cl_platform_id platform, cl_platform_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetPlatformInfo; + MS_ASSERT(func != nullptr); + return func(platform, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetDeviceIDs wrapper, use OpenCLWrapper function. +cl_int clGetDeviceIDs(cl_platform_id platform, cl_device_type device_type, cl_uint num_entries, cl_device_id *devices, + cl_uint *num_devices) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetDeviceIDs; + MS_ASSERT(func != nullptr); + return func(platform, device_type, num_entries, devices, num_devices); +} + +// clGetDeviceInfo wrapper, use OpenCLWrapper function. +cl_int clGetDeviceInfo(cl_device_id device, cl_device_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetDeviceInfo; + MS_ASSERT(func != nullptr); + return func(device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clCreateContext wrapper, use OpenCLWrapper function. +cl_context clCreateContext(const cl_context_properties *properties, cl_uint num_devices, const cl_device_id *devices, + void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *), void *user_data, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateContext; + MS_ASSERT(func != nullptr); + return func(properties, num_devices, devices, pfn_notify, user_data, errcode_ret); +} + +// clCreateContextFromType wrapper, use OpenCLWrapper function. +cl_context clCreateContextFromType(const cl_context_properties *properties, cl_device_type device_type, + void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *), + void *user_data, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateContextFromType; + MS_ASSERT(func != nullptr); + return func(properties, device_type, pfn_notify, user_data, errcode_ret); +} + +// clRetainContext wrapper, use OpenCLWrapper function. +cl_int clRetainContext(cl_context context) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainContext; + MS_ASSERT(func != nullptr); + return func(context); +} + +// clReleaseContext wrapper, use OpenCLWrapper function. +cl_int clReleaseContext(cl_context context) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseContext; + MS_ASSERT(func != nullptr); + return func(context); +} + +// clGetContextInfo wrapper, use OpenCLWrapper function. +cl_int clGetContextInfo(cl_context context, cl_context_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetContextInfo; + MS_ASSERT(func != nullptr); + return func(context, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clCreateProgramWithSource wrapper, use OpenCLWrapper function. +cl_program clCreateProgramWithSource(cl_context context, cl_uint count, const char **strings, const size_t *lengths, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateProgramWithSource; + MS_ASSERT(func != nullptr); + return func(context, count, strings, lengths, errcode_ret); +} + +// clGetProgramInfo wrapper, use OpenCLWrapper function. +cl_int clGetProgramInfo(cl_program program, cl_program_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetProgramInfo; + MS_ASSERT(func != nullptr); + return func(program, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetProgramBuildInfo wrapper, use OpenCLWrapper function. +cl_int clGetProgramBuildInfo(cl_program program, cl_device_id device, cl_program_build_info param_name, + size_t param_value_size, void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetProgramBuildInfo; + MS_ASSERT(func != nullptr); + return func(program, device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clRetainProgram wrapper, use OpenCLWrapper function. +cl_int clRetainProgram(cl_program program) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainProgram; + MS_ASSERT(func != nullptr); + return func(program); +} + +// clReleaseProgram wrapper, use OpenCLWrapper function. +cl_int clReleaseProgram(cl_program program) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseProgram; + MS_ASSERT(func != nullptr); + return func(program); +} + +// clBuildProgram wrapper, use OpenCLWrapper function. +cl_int clBuildProgram(cl_program program, cl_uint num_devices, const cl_device_id *device_list, const char *options, + void(CL_CALLBACK *pfn_notify)(cl_program program, void *user_data), void *user_data) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clBuildProgram; + MS_ASSERT(func != nullptr); + return func(program, num_devices, device_list, options, pfn_notify, user_data); +} + +// clCreateKernel wrapper, use OpenCLWrapper function. +cl_kernel clCreateKernel(cl_program program, const char *kernelName, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateKernel; + MS_ASSERT(func != nullptr); + return func(program, kernelName, errcode_ret); +} + +// clRetainKernel wrapper, use OpenCLWrapper function. +cl_int clRetainKernel(cl_kernel kernel) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainKernel; + MS_ASSERT(func != nullptr); + return func(kernel); +} + +// clReleaseKernel wrapper, use OpenCLWrapper function. +cl_int clReleaseKernel(cl_kernel kernel) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseKernel; + MS_ASSERT(func != nullptr); + return func(kernel); +} + +// clSetKernelArg wrapper, use OpenCLWrapper function. +cl_int clSetKernelArg(cl_kernel kernel, cl_uint arg_index, size_t arg_size, const void *arg_value) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSetKernelArg; + MS_ASSERT(func != nullptr); + return func(kernel, arg_index, arg_size, arg_value); +} + +// clCreateBuffer wrapper, use OpenCLWrapper function. +cl_mem clCreateBuffer(cl_context context, cl_mem_flags flags, size_t size, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateBuffer; + MS_ASSERT(func != nullptr); + return func(context, flags, size, host_ptr, errcode_ret); +} + +// clRetainMemObject wrapper, use OpenCLWrapper function. +cl_int clRetainMemObject(cl_mem memobj) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainMemObject; + MS_ASSERT(func != nullptr); + return func(memobj); +} + +// clReleaseMemObject wrapper, use OpenCLWrapper function. +cl_int clReleaseMemObject(cl_mem memobj) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseMemObject; + MS_ASSERT(func != nullptr); + return func(memobj); +} + +// clGetImageInfo wrapper, use OpenCLWrapper function. +cl_int clGetImageInfo(cl_mem image, cl_image_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetImageInfo; + MS_ASSERT(func != nullptr); + return func(image, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clRetainCommandQueue wrapper, use OpenCLWrapper function. +cl_int clRetainCommandQueue(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainCommandQueue; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clReleaseCommandQueue wrapper, use OpenCLWrapper function. +cl_int clReleaseCommandQueue(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseCommandQueue; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clEnqueueReadBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueReadBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_read, size_t offset, + size_t size, void *ptr, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, + cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueReadBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_read, offset, size, ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueWriteBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueWriteBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_write, size_t offset, + size_t size, const void *ptr, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueWriteBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_write, offset, size, ptr, num_events_in_wait_list, event_wait_list, + event); +} + +// clEnqueueWriteImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueWriteImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_write, const size_t *origin, + const size_t *region, size_t input_row_pitch, size_t input_slice_pitch, const void *ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueWriteImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_write, origin, region, input_row_pitch, input_slice_pitch, ptr, + num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueReadImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueReadImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_read, const size_t *origin, + const size_t *region, size_t row_pitch, size_t slice_pitch, void *ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueReadImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_read, origin, region, row_pitch, slice_pitch, ptr, num_events_in_wait_list, + event_wait_list, event); +} + +// clEnqueueMapBuffer wrapper, use OpenCLWrapper function. +void *clEnqueueMapBuffer(cl_command_queue command_queue, cl_mem buffer, cl_bool blocking_map, cl_map_flags map_flags, + size_t offset, size_t size, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, + cl_event *event, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueMapBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, buffer, blocking_map, map_flags, offset, size, num_events_in_wait_list, event_wait_list, + event, errcode_ret); +} + +// clEnqueueMapImage wrapper, use OpenCLWrapper function. +void *clEnqueueMapImage(cl_command_queue command_queue, cl_mem image, cl_bool blocking_map, cl_map_flags map_flags, + const size_t *origin, const size_t *region, size_t *image_row_pitch, size_t *image_slice_pitch, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueMapImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, blocking_map, map_flags, origin, region, image_row_pitch, image_slice_pitch, + num_events_in_wait_list, event_wait_list, event, errcode_ret); +} + +// clEnqueueUnmapMemObject wrapper, use OpenCLWrapper function. +cl_int clEnqueueUnmapMemObject(cl_command_queue command_queue, cl_mem memobj, void *mapped_ptr, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueUnmapMemObject; + MS_ASSERT(func != nullptr); + return func(command_queue, memobj, mapped_ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clGetKernelWorkGroupInfo wrapper, use OpenCLWrapper function. +cl_int clGetKernelWorkGroupInfo(cl_kernel kernel, cl_device_id device, cl_kernel_work_group_info param_name, + size_t param_value_size, void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetKernelWorkGroupInfo; + MS_ASSERT(func != nullptr); + return func(kernel, device, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clGetEventProfilingInfo wrapper, use OpenCLWrapper function. +cl_int clGetEventProfilingInfo(cl_event event, cl_profiling_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetEventProfilingInfo; + MS_ASSERT(func != nullptr); + return func(event, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clEnqueueNDRangeKernel wrapper, use OpenCLWrapper function. +cl_int clEnqueueNDRangeKernel(cl_command_queue command_queue, cl_kernel kernel, cl_uint work_dim, + const size_t *global_work_offset, const size_t *global_work_size, + const size_t *local_work_size, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueNDRangeKernel; + MS_ASSERT(func != nullptr); + return func(command_queue, kernel, work_dim, global_work_offset, global_work_size, local_work_size, + num_events_in_wait_list, event_wait_list, event); +} + +// clWaitForEvents wrapper, use OpenCLWrapper function. +cl_int clWaitForEvents(cl_uint num_events, const cl_event *event_list) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clWaitForEvents; + MS_ASSERT(func != nullptr); + return func(num_events, event_list); +} + +// clRetainEvent wrapper, use OpenCLWrapper function. +cl_int clRetainEvent(cl_event event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainEvent; + MS_ASSERT(func != nullptr); + return func(event); +} + +// clReleaseEvent wrapper, use OpenCLWrapper function. +cl_int clReleaseEvent(cl_event event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseEvent; + MS_ASSERT(func != nullptr); + return func(event); +} + +// clGetEventInfo wrapper, use OpenCLWrapper function. +cl_int clGetEventInfo(cl_event event, cl_event_info param_name, size_t param_value_size, void *param_value, + size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetEventInfo; + MS_ASSERT(func != nullptr); + return func(event, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clFlush wrapper, use OpenCLWrapper function. +cl_int clFlush(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clFlush; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clFinish wrapper, use OpenCLWrapper function. +cl_int clFinish(cl_command_queue command_queue) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clFinish; + MS_ASSERT(func != nullptr); + return func(command_queue); +} + +// clCreateImage2D wrapper, use OpenCLWrapper function. +cl_mem clCreateImage2D(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, size_t imageWidth, + size_t imageHeight, size_t image_row_pitch, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage2D; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, imageWidth, imageHeight, image_row_pitch, host_ptr, errcode_ret); +} + +// clCreateImage3D wrapper, use OpenCLWrapper function. +cl_mem clCreateImage3D(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, size_t imageWidth, + size_t imageHeight, size_t imageDepth, size_t image_row_pitch, size_t image_slice_pitch, + void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage3D; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, imageWidth, imageHeight, imageDepth, image_row_pitch, image_slice_pitch, + host_ptr, errcode_ret); +} + +// clCreateCommandQueue wrapper, use OpenCLWrapper function. +cl_command_queue clCreateCommandQueue(cl_context context, cl_device_id device, cl_command_queue_properties properties, + cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateCommandQueue; + MS_ASSERT(func != nullptr); + return func(context, device, properties, errcode_ret); +} + +// clGetCommandQueueInfo wrapper, use OpenCLWrapper function. +cl_int clGetCommandQueueInfo(cl_command_queue command_queue, cl_command_queue_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetCommandQueueInfo; + MS_ASSERT(func != nullptr); + return func(command_queue, param_name, param_value_size, param_value, param_value_size_ret); +} + +// clEnqueueCopyImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyImage(cl_command_queue queue, cl_mem src_image, cl_mem dst_image, const size_t *src_origin, + const size_t *dst_origin, const size_t *region, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyImage; + MS_ASSERT(func != nullptr); + return func(queue, src_image, dst_image, src_origin, dst_origin, region, num_events_in_wait_list, event_wait_list, + event); +} + +// clEnqueueCopyBufferToImage wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyBufferToImage(cl_command_queue command_queue, cl_mem src_buffer, cl_mem dst_image, + size_t src_offset, const size_t *dst_origin, const size_t *region, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyBufferToImage; + MS_ASSERT(func != nullptr); + return func(command_queue, src_buffer, dst_image, src_offset, dst_origin, region, num_events_in_wait_list, + event_wait_list, event); +} + +// clEnqueueCopyImageToBuffer wrapper, use OpenCLWrapper function. +cl_int clEnqueueCopyImageToBuffer(cl_command_queue command_queue, cl_mem src_image, cl_mem dst_buffer, + const size_t *src_origin, const size_t *region, size_t dst_offset, + cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueCopyImageToBuffer; + MS_ASSERT(func != nullptr); + return func(command_queue, src_image, dst_buffer, src_origin, region, dst_offset, num_events_in_wait_list, + event_wait_list, event); +} + +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + +// clRetainDevice wrapper, use OpenCLWrapper function. +cl_int clRetainDevice(cl_device_id device) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clRetainDevice; + MS_ASSERT(func != nullptr); + return func(device); +} + +// clReleaseDevice wrapper, use OpenCLWrapper function. +cl_int clReleaseDevice(cl_device_id device) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clReleaseDevice; + MS_ASSERT(func != nullptr); + return func(device); +} + +// clCreateImage wrapper, use OpenCLWrapper function. +cl_mem clCreateImage(cl_context context, cl_mem_flags flags, const cl_image_format *image_format, + const cl_image_desc *image_desc, void *host_ptr, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateImage; + MS_ASSERT(func != nullptr); + return func(context, flags, image_format, image_desc, host_ptr, errcode_ret); +} + +#endif + +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 +#if 0 +// clGetKernelSubGroupInfoKHR wrapper, use OpenCLWrapper function. +cl_int clGetKernelSubGroupInfoKHR(cl_kernel kernel, cl_device_id device, cl_kernel_sub_group_info param_name, + size_t input_value_size, const void *input_value, size_t param_value_size, + void *param_value, size_t *param_value_size_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetKernelSubGroupInfoKHR; + MS_ASSERT(func != nullptr); + return func(kernel, device, param_name, input_value_size, input_value, param_value_size, param_value, + param_value_size_ret); +} +#endif + +// clCreateCommandQueueWithProperties wrapper, use OpenCLWrapper function. +cl_command_queue clCreateCommandQueueWithProperties(cl_context context, cl_device_id device, + const cl_queue_properties *properties, cl_int *errcode_ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateCommandQueueWithProperties; + MS_ASSERT(func != nullptr); + return func(context, device, properties, errcode_ret); +} + +// clGetExtensionFunctionAddress wrapper, use OpenCLWrapper function. +void *clGetExtensionFunctionAddress(const char *func_name) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clGetExtensionFunctionAddress; + MS_ASSERT(func != nullptr); + return func(func_name); +} +// clCreateProgramWithIL wrapper, use OpenCLWrapper function. +cl_program clCreateProgramWithIL(cl_context context, const void *il, size_t length, cl_int *ret) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clCreateProgramWithIL; + MS_ASSERT(func != nullptr); + return func(context, il, length, ret); +} + +// clSVMAlloc wrapper, use OpenCLWrapper function. +void *clSVMAlloc(cl_context context, cl_mem_flags flags, size_t size, cl_uint align) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSVMAlloc; + MS_ASSERT(func != nullptr); + return func(context, flags, size, align); +} + +// clSVMFree wrapper, use OpenCLWrapper function. +void clSVMFree(cl_context context, void *buffer) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSVMFree; + MS_ASSERT(func != nullptr); + func(context, buffer); +} + +// clEnqueueSVMMap wrapper, use OpenCLWrapper function. +cl_int clEnqueueSVMMap(cl_command_queue command_queue, cl_bool blocking, cl_map_flags flags, void *host_ptr, + size_t size, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueSVMMap; + MS_ASSERT(func != nullptr); + return func(command_queue, blocking, flags, host_ptr, size, num_events_in_wait_list, event_wait_list, event); +} + +// clEnqueueSVMUnmap wrapper, use OpenCLWrapper function. +cl_int clEnqueueSVMUnmap(cl_command_queue command_queue, void *host_ptr, cl_uint num_events_in_wait_list, + const cl_event *event_wait_list, cl_event *event) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clEnqueueSVMUnmap; + MS_ASSERT(func != nullptr); + return func(command_queue, host_ptr, num_events_in_wait_list, event_wait_list, event); +} + +// clSetKernelArgSVMPointer wrapper, use OpenCLWrapper function. +cl_int clSetKernelArgSVMPointer(cl_kernel kernel, cl_uint index, const void *host_ptr) { + auto func = mindspore::lite::opencl::OpenCLWrapper::GetInstance()->clSetKernelArgSVMPointer; + MS_ASSERT(func != nullptr); + return func(kernel, index, host_ptr); +} +#endif + +#endif // USE_OPENCL_WRAPPER + diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.h b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h new file mode 100644 index 00000000000..d4f0d98f9a7 --- /dev/null +++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h @@ -0,0 +1,240 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ +#define MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ + +#include +#include +#include + +// support opencl min version is 1.1 +#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 210 +#endif +#ifndef CL_HPP_TARGET_OPENCL_VERSION +#define CL_HPP_TARGET_OPENCL_VERSION 210 +#endif +#ifndef CL_HPP_MINIMUM_OPENCL_VERSION +#define CL_HPP_MINIMUM_OPENCL_VERSION 110 +#endif + +#include "CL/cl2.hpp" + +#ifdef USE_OPENCL_WRAPPER + +namespace mindspore::lite::opencl { + +// This is a opencl function wrapper. +class OpenCLWrapper { + public: + static OpenCLWrapper *GetInstance(); + + ~OpenCLWrapper(); + OpenCLWrapper(const OpenCLWrapper &) = delete; + OpenCLWrapper &operator=(const OpenCLWrapper &) = delete; + + bool LoadOpenCLLibrary(); + bool UnLoadOpenCLLibrary(); + // get platfrom id + using clGetPlatformIDsFunc = cl_int (*)(cl_uint, cl_platform_id *, cl_uint *); + // get platform info + using clGetPlatformInfoFunc = cl_int (*)(cl_platform_id, cl_platform_info, size_t, void *, size_t *); + // build program + using clBuildProgramFunc = cl_int (*)(cl_program, cl_uint, const cl_device_id *, const char *, + void (*pfn_notify)(cl_program, void *), void *); + // enqueue run kernel + using clEnqueueNDRangeKernelFunc = cl_int (*)(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + // set kernel parameter + using clSetKernelArgFunc = cl_int (*)(cl_kernel, cl_uint, size_t, const void *); + using clRetainMemObjectFunc = cl_int (*)(cl_mem); + using clReleaseMemObjectFunc = cl_int (*)(cl_mem); + using clEnqueueUnmapMemObjectFunc = cl_int (*)(cl_command_queue, cl_mem, void *, cl_uint, const cl_event *, + cl_event *); + using clRetainCommandQueueFunc = cl_int (*)(cl_command_queue command_queue); + // create context + using clCreateContextFunc = cl_context (*)(const cl_context_properties *, cl_uint, const cl_device_id *, + void(CL_CALLBACK *)( // NOLINT(readability/casting) + const char *, const void *, size_t, void *), + void *, cl_int *); + using clEnqueueCopyImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_mem, const size_t *, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + + using clCreateContextFromTypeFunc = cl_context (*)(const cl_context_properties *, cl_device_type, + void(CL_CALLBACK *)( // NOLINT(readability/casting) + const char *, const void *, size_t, void *), + void *, cl_int *); + using clReleaseContextFunc = cl_int (*)(cl_context); + using clWaitForEventsFunc = cl_int (*)(cl_uint, const cl_event *); + using clReleaseEventFunc = cl_int (*)(cl_event); + using clEnqueueWriteBufferFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, cl_uint, + const cl_event *, cl_event *); + using clEnqueueWriteImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, const size_t *, const size_t *, size_t, + size_t, const void *, cl_uint, const cl_event *, cl_event *); + using clEnqueueReadImageFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, const size_t *, const size_t *, size_t, + size_t, void *, cl_uint, const cl_event *, cl_event *); + using clEnqueueReadBufferFunc = cl_int (*)(cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, cl_uint, + const cl_event *, cl_event *); + using clGetProgramBuildInfoFunc = cl_int (*)(cl_program, cl_device_id, cl_program_build_info, size_t, void *, + size_t *); + using clRetainProgramFunc = cl_int (*)(cl_program program); + using clEnqueueMapBufferFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_map_flags, size_t, size_t, cl_uint, + const cl_event *, cl_event *, cl_int *); + using clEnqueueMapImageFunc = void *(*)(cl_command_queue, cl_mem, cl_bool, cl_map_flags, const size_t *, + const size_t *, size_t *, size_t *, cl_uint, const cl_event *, cl_event *, + cl_int *); + using clCreateCommandQueueFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id, + cl_command_queue_properties, cl_int *); + using clGetCommandQueueInfoFunc = cl_int (*)(cl_command_queue, cl_command_queue_info, size_t, void *, size_t *); + using clReleaseCommandQueueFunc = cl_int (*)(cl_command_queue); + using clCreateProgramWithBinaryFunc = cl_program (*)(cl_context, cl_uint, const cl_device_id *, const size_t *, + const unsigned char **, cl_int *, cl_int *); + using clRetainContextFunc = cl_int (*)(cl_context context); + using clGetContextInfoFunc = cl_int (*)(cl_context, cl_context_info, size_t, void *, size_t *); + using clReleaseProgramFunc = cl_int (*)(cl_program program); + using clFlushFunc = cl_int (*)(cl_command_queue command_queue); + using clFinishFunc = cl_int (*)(cl_command_queue command_queue); + using clGetProgramInfoFunc = cl_int (*)(cl_program, cl_program_info, size_t, void *, size_t *); + using clCreateKernelFunc = cl_kernel (*)(cl_program, const char *, cl_int *); + using clRetainKernelFunc = cl_int (*)(cl_kernel kernel); + using clCreateBufferFunc = cl_mem (*)(cl_context, cl_mem_flags, size_t, void *, cl_int *); + using clCreateImage2DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, + size_t, void *, cl_int *); + using clCreateImage3DFunc = cl_mem(CL_API_CALL *)(cl_context, cl_mem_flags, const cl_image_format *, size_t, size_t, + size_t, size_t, size_t, void *, cl_int *); + using clCreateProgramWithSourceFunc = cl_program (*)(cl_context, cl_uint, const char **, const size_t *, cl_int *); + using clReleaseKernelFunc = cl_int (*)(cl_kernel kernel); + using clGetDeviceInfoFunc = cl_int (*)(cl_device_id, cl_device_info, size_t, void *, size_t *); + using clGetDeviceIDsFunc = cl_int (*)(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *); + using clRetainEventFunc = cl_int (*)(cl_event); + using clGetKernelWorkGroupInfoFunc = cl_int (*)(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, + size_t *); + using clGetEventInfoFunc = cl_int (*)(cl_event event, cl_event_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret); + using clGetEventProfilingInfoFunc = cl_int (*)(cl_event event, cl_profiling_info param_name, size_t param_value_size, + void *param_value, size_t *param_value_size_ret); + using clGetImageInfoFunc = cl_int (*)(cl_mem, cl_image_info, size_t, void *, size_t *); + using clEnqueueCopyBufferToImageFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, size_t, const size_t *, + const size_t *, cl_uint, const cl_event *, cl_event *); + using clEnqueueCopyImageToBufferFunc = cl_int(CL_API_CALL *)(cl_command_queue, cl_mem, cl_mem, const size_t *, + const size_t *, size_t, cl_uint, const cl_event *, + cl_event *); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + using clRetainDeviceFunc = cl_int (*)(cl_device_id); + using clReleaseDeviceFunc = cl_int (*)(cl_device_id); + using clCreateImageFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, + cl_int *); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + using clCreateProgramWithILFunc = cl_program (*)(cl_context, const void *, size_t, cl_int *); + using clSVMAllocFunc = void *(*)(cl_context, cl_mem_flags, size_t size, cl_uint); + using clSVMFreeFunc = void (*)(cl_context, void *); + using clEnqueueSVMMapFunc = cl_int (*)(cl_command_queue, cl_bool, cl_map_flags, void *, size_t, cl_uint, + const cl_event *, cl_event *); + using clEnqueueSVMUnmapFunc = cl_int (*)(cl_command_queue, void *, cl_uint, const cl_event *, cl_event *); + using clSetKernelArgSVMPointerFunc = cl_int (*)(cl_kernel, cl_uint, const void *); + // opencl 2.0 can get sub group info and wave size. + using clGetKernelSubGroupInfoKHRFunc = cl_int(CL_API_CALL *)(cl_kernel, cl_device_id, cl_kernel_sub_group_info, + size_t, const void *, size_t, void *, size_t *); + using clCreateCommandQueueWithPropertiesFunc = cl_command_queue(CL_API_CALL *)(cl_context, cl_device_id, + const cl_queue_properties *, cl_int *); + using clGetExtensionFunctionAddressFunc = void *(CL_API_CALL *)(const char *); +#endif + +#define CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr + + CL_DEFINE_FUNC_PTR(clGetPlatformIDs); + CL_DEFINE_FUNC_PTR(clGetPlatformInfo); + CL_DEFINE_FUNC_PTR(clBuildProgram); + CL_DEFINE_FUNC_PTR(clEnqueueNDRangeKernel); + CL_DEFINE_FUNC_PTR(clSetKernelArg); + CL_DEFINE_FUNC_PTR(clReleaseKernel); + CL_DEFINE_FUNC_PTR(clCreateProgramWithSource); + CL_DEFINE_FUNC_PTR(clCreateBuffer); + CL_DEFINE_FUNC_PTR(clCreateImage2D); + CL_DEFINE_FUNC_PTR(clCreateImage3D); + CL_DEFINE_FUNC_PTR(clRetainKernel); + CL_DEFINE_FUNC_PTR(clCreateKernel); + CL_DEFINE_FUNC_PTR(clGetProgramInfo); + CL_DEFINE_FUNC_PTR(clFlush); + CL_DEFINE_FUNC_PTR(clFinish); + CL_DEFINE_FUNC_PTR(clReleaseProgram); + CL_DEFINE_FUNC_PTR(clRetainContext); + CL_DEFINE_FUNC_PTR(clGetContextInfo); + CL_DEFINE_FUNC_PTR(clCreateProgramWithBinary); + CL_DEFINE_FUNC_PTR(clCreateCommandQueue); + CL_DEFINE_FUNC_PTR(clGetCommandQueueInfo); + CL_DEFINE_FUNC_PTR(clReleaseCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueMapBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueMapImage); + CL_DEFINE_FUNC_PTR(clEnqueueCopyImage); + CL_DEFINE_FUNC_PTR(clRetainProgram); + CL_DEFINE_FUNC_PTR(clGetProgramBuildInfo); + CL_DEFINE_FUNC_PTR(clEnqueueReadBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueWriteBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueWriteImage); + CL_DEFINE_FUNC_PTR(clEnqueueReadImage); + CL_DEFINE_FUNC_PTR(clWaitForEvents); + CL_DEFINE_FUNC_PTR(clReleaseEvent); + CL_DEFINE_FUNC_PTR(clCreateContext); + CL_DEFINE_FUNC_PTR(clCreateContextFromType); + CL_DEFINE_FUNC_PTR(clReleaseContext); + CL_DEFINE_FUNC_PTR(clRetainCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueUnmapMemObject); + CL_DEFINE_FUNC_PTR(clRetainMemObject); + CL_DEFINE_FUNC_PTR(clReleaseMemObject); + CL_DEFINE_FUNC_PTR(clGetDeviceInfo); + CL_DEFINE_FUNC_PTR(clGetDeviceIDs); + CL_DEFINE_FUNC_PTR(clRetainEvent); + CL_DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo); + CL_DEFINE_FUNC_PTR(clGetEventInfo); + CL_DEFINE_FUNC_PTR(clGetEventProfilingInfo); + CL_DEFINE_FUNC_PTR(clGetImageInfo); + CL_DEFINE_FUNC_PTR(clEnqueueCopyBufferToImage); + CL_DEFINE_FUNC_PTR(clEnqueueCopyImageToBuffer); +#if CL_HPP_TARGET_OPENCL_VERSION >= 120 + CL_DEFINE_FUNC_PTR(clRetainDevice); + CL_DEFINE_FUNC_PTR(clReleaseDevice); + CL_DEFINE_FUNC_PTR(clCreateImage); +#endif +#if CL_HPP_TARGET_OPENCL_VERSION >= 200 + CL_DEFINE_FUNC_PTR(clGetKernelSubGroupInfoKHR); + CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties); + CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddress); + CL_DEFINE_FUNC_PTR(clCreateProgramWithIL); + CL_DEFINE_FUNC_PTR(clSVMAlloc); + CL_DEFINE_FUNC_PTR(clSVMFree); + CL_DEFINE_FUNC_PTR(clEnqueueSVMMap); + CL_DEFINE_FUNC_PTR(clEnqueueSVMUnmap); + CL_DEFINE_FUNC_PTR(clSetKernelArgSVMPointer); +#endif + +#undef TNN_CL_DEFINE_FUNC_PTR + + private: + OpenCLWrapper(); + bool LoadLibraryFromPath(const std::string &path); + + private: + static std::shared_ptr opencl_wrapper_singleton_; + void *handle_ = nullptr; +}; + +} // namespace mindspore::lite::opencl +#endif // USE_OPENCL_WRAPPER +#endif // MINDSPORE_LITE_SRC_OPENCL_WRAPPER_H_ + diff --git a/mindspore/lite/src/runtime/runtime_api.cc b/mindspore/lite/src/runtime/runtime_api.cc new file mode 100644 index 00000000000..460ae4b07a8 --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_api.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "src/runtime/runtime_api.h" +#include "src/runtime/workspace_pool.h" +#include "src/runtime/thread_pool.h" +#include "utils/log_adapter.h" + +static std::mutex gWorkspaceMutex; +#ifdef __cplusplus +extern "C" { +#endif +void LiteAPISetLastError(const char *msg) { + MS_LOG(ERROR) << "The lite api set last error is " << msg; +} + +void *LiteBackendAllocWorkspace(int deviceType, + int deviceId, + uint64_t size, + int dtypeCode, + int dtypeBits) { + std::lock_guard lock(gWorkspaceMutex); + auto p = mindspore::predict::WorkspacePool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return nullptr; + } + return p->AllocWorkSpaceMem(size); +} + +int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr) { + std::lock_guard lock(gWorkspaceMutex); + auto p = mindspore::predict::WorkspacePool::GetInstance(); + if (p == nullptr) { + return -1; + } + p->FreeWorkSpaceMem(ptr); + return 0; +} + +void SetMaxWokerNum(int num) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + if (num < 0) { + LiteAPISetLastError("The number of work thread is less than 0"); + return; + } + p->ConfigMaxThreadNum(num); +} + +void ConfigThreadPool(int mode, int nthreads) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + p->ConfigThreadPool(mode, nthreads); +} + +int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return -1; + } + if (!p->LaunchWork(flambda, cdata, num_task)) { + MS_LOG(ERROR) << "launch thread pool work failed"; + return -1; + } + return 0; +} + +void DoAllThreadBind(bool ifBind, int mode) { + auto p = mindspore::predict::ThreadPool::GetInstance(); + if (p == nullptr) { + MS_LOG(ERROR) << "Get thread pool instance failed"; + return; + } + if (!p->BindAllThreads(ifBind, mode)) { + MS_LOG(ERROR) << "do thread cpu bind failed"; + } +} + +#ifdef __cplusplus +} +#endif + diff --git a/mindspore/lite/src/runtime/runtime_api.h b/mindspore/lite/src/runtime/runtime_api.h new file mode 100644 index 00000000000..cd3942d79e3 --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_api.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ +#include + +#ifndef INTERNAL_API_DLL +#ifdef _WIN32 +#ifdef LITE_EXPORTS +#define INTERNAL_API_DLL __declspec(dllexport) +#else +#define INTERNAL_API_DLL __declspec(dllimport) +#endif +#else +#define INTERNAL_API_DLL __attribute__((visibility("default"))) +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + void *sync_handle; + int32_t num_task; +} LiteParallelGroupEnv; +typedef int (*FTVMParallelLambda)(int task_id, LiteParallelGroupEnv *penv, void *cdata); +INTERNAL_API_DLL void LiteAPISetLastError(const char *msg); +INTERNAL_API_DLL void *LiteBackendAllocWorkspace(int deviceType, int deviceId, uint64_t size, int dtypeCode, + int dtypeBits); +INTERNAL_API_DLL int LiteBackendFreeWorkspace(int deviceType, int deviceId, void *ptr); +INTERNAL_API_DLL void SetMaxWokerNum(int num); +INTERNAL_API_DLL void ConfigThreadPool(int mode, int nthreads); +INTERNAL_API_DLL inline void CfgThreadPool(int nthread) { ConfigThreadPool(-1, nthread); } +INTERNAL_API_DLL int LiteBackendParallelLaunch(FTVMParallelLambda flambda, void *cdata, int num_task); +INTERNAL_API_DLL int LiteBackendRegisterSystemLibSymbol(const char *name, void *ptr); +INTERNAL_API_DLL void DoAllThreadBind(bool ifBind, int mode); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_ + diff --git a/mindspore/lite/src/runtime/thread_pool.cc b/mindspore/lite/src/runtime/thread_pool.cc new file mode 100644 index 00000000000..e9d9c8f1dc1 --- /dev/null +++ b/mindspore/lite/src/runtime/thread_pool.cc @@ -0,0 +1,456 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/thread_pool.h" +#include +#include "utils/log_adapter.h" +#ifdef MS_COMPILE_IOS +#include +#include +#include +#endif // MS_COMPILE_IOS + +namespace mindspore { +namespace predict { +constexpr int kDefaultBigCount = 2; +constexpr int kDefaultMidCount = 2; +constexpr int kSmallCpuNum = 4; +constexpr int kBigMidCpuNum = 4; +constexpr int kDefaultThreadNum = 1; +static unsigned int kDefaultMaxThreadNums = 8; +static unsigned int localMaxThreadNums = 1; + +bool LiteQueue::Enqueue(ThreadPoolTask *task) { + const int tailIndex = tail.load(std::memory_order_relaxed); + // queue full + auto next = (tailIndex + 1) % kSingleThreadMaxTask; + if (next == head.load(std::memory_order_acquire)) { + return false; + } + buffer[tailIndex] = task; + tail.store(next, std::memory_order_release); + ++taskSize; + return true; +} + +bool LiteQueue::Dequeue(ThreadPoolTask **out) { + if (taskSize == 0) { + return false; + } + // queue empty + const int headIndex = head.load(std::memory_order_relaxed); + if (headIndex == tail.load(std::memory_order_acquire)) { + return false; + } + *out = buffer[headIndex]; + head.store((headIndex + 1) % kSingleThreadMaxTask, std::memory_order_release); + return true; +} + +bool LiteThreadBind::Bind(bool ifBind, int numThreads, bool master) { + if (master) { + if (!BindMasterThread(ifBind, bindModel)) { + MS_LOG(ERROR) << "bind msater thread failed"; + return false; + } + MS_LOG(DEBUG) << "bind master thread successful"; + } + if (numThreads > static_cast(sortedCpuIds.size())) { + MS_LOG(ERROR) << "thread num " << numThreads << " is larger than cores " << static_cast(sortedCpuIds.size()) + << " in the system"; + return true; + } + + if (!BindThreads(ifBind)) { + MS_LOG(ERROR) << "action " << ifBind << " thread failed"; + return false; + } + MS_LOG(DEBUG) << "action " << ifBind << " thread successful"; + return true; +} + +void LiteThreadBind::InitSortedCpuId() { + // mate10(970)|p20(970): 4big, 4small + // mate20(980)|p30(980)|mate30(990): 2big, 2mid, 4small + // note: p30's core 7 not allowed to be bind + int numCores = 0; +#ifdef MS_COMPILE_IOS + size_t len = sizeof(numCores); + sysctlbyname("hw.ncpu", &numCores, &len, NULL, 0); + numCores = numCores > 1 ? numCores : 1; +#else + numCores = static_cast(std::thread::hardware_concurrency()); +#endif // MS_COMPILE_IOS + if (numCores < kBigMidCpuNum) { + bigCore = 0; + midCore = numCores; + } else { + bigCore = kDefaultBigCount; + midCore = kDefaultMidCount; + } + sortedCpuIds.clear(); + for (int i = numCores - 1; i >= 0; --i) { + sortedCpuIds.emplace_back(i); + } + if (sortedCpuIds.size() > kSmallCpuNum) { + sortedCpuIds.resize(bigCore + midCore); + } +} + +bool LiteThreadBind::BindMasterThread(bool bindFlag, int mode) { + std::vector cpu; + if (bindFlag) { + size_t cpuIndex; + if (mode == MID_CORE) { + cpuIndex = sortedCpuIds.size() - 1; + } else { + cpuIndex = 0; + } + cpu.emplace_back(sortedCpuIds[cpuIndex]); + } else { + // unbind master + cpu.assign(sortedCpuIds.begin(), sortedCpuIds.end()); + } + cpu_set_t cpuSet; +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + for (auto coreId : cpu) { +#ifndef CPU_SET + CPU_SET_LOCAL(coreId, &cpuSet); +#else + CPU_SET(coreId, &cpuSet); +#endif + } + if (!SetCPUBind(pthread_self(), &cpuSet)) { + MS_LOG(ERROR) << "do master bind failed. mode: " << mode; + return false; + } + return true; +} + +bool LiteThreadBind::BindThreads(bool bindFlag) { + if (bindFlag && bindModel != NO_BIND) { + size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); + cpu_set_t cpuSet; + size_t coreIndex; + for (size_t i = 0; i < bindNums; ++i) { +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + if (bindModel == MID_CORE) { + coreIndex = sortedCpuIds.size() - 2 - i; + } else { + coreIndex = i + 1; + } +#ifndef CPU_SET + CPU_SET_LOCAL(sortedCpuIds[coreIndex], &cpuSet); +#else + CPU_SET(sortedCpuIds[coreIndex], &cpuSet); +#endif + if (!SetCPUBind(threadIdList[i], &cpuSet)) { + MS_LOG(ERROR) << "do SetCPUBind failed"; + return false; + } + } + } else { + // unbind + size_t bindNums = std::min(sortedCpuIds.size(), threadIdList.size()); + cpu_set_t cpuSet; +#ifndef CPU_SET + (void)memset(&cpuSet, 0, sizeof(cpu_set_t)); +#else + CPU_ZERO(&cpuSet); +#endif + for (auto coreId : sortedCpuIds) { +#ifndef CPU_SET + CPU_SET_LOCAL(coreId, &cpuSet); +#else + CPU_SET(coreId, &cpuSet); +#endif + } + for (size_t i = 0; i < bindNums; ++i) { + if (!SetCPUBind(threadIdList[i], &cpuSet)) { + MS_LOG(ERROR) << "do SetCPUBind failed"; + return false; + } + } + } + return true; +} + +bool LiteThreadBind::SetCPUBind(pthread_t threadId, cpu_set_t *cpuSet) { +#if defined(__ANDROID__) +#if __ANDROID_API__ >= 21 + int ret = sched_setaffinity(pthread_gettid_np(threadId), sizeof(cpu_set_t), cpuSet); + if (ret != 0) { + MS_LOG(ERROR) << "bind thread %ld to cpu failed.ERROR %d", threadId, ret; + } +#endif +#else +#ifdef __APPLE__ + MS_LOG(ERROR) << "not bind thread to apple's cpu."; + return false; +#else + int ret = pthread_setaffinity_np(threadId, sizeof(cpuSet), cpuSet); + if (ret != 0) { + MS_LOG(ERROR) << "bind thread " << threadId << " to cpu failed.ERROR " << ret; + return false; + } +#endif // __APPLE__ +#endif + return true; +} + +bool ThreadPool::SetThreadPool() { + std::lock_guard Lock(poolMutex); + if (configThreadNums <= 0) { + MS_LOG(WARNING) << "numThreads " << configThreadNums << ", must be greater than 0"; + configThreadNums = curThreadRunNums; + } + if (localMaxThreadNums == 0) { + localMaxThreadNums = 1; + } else if (localMaxThreadNums > kDefaultMaxThreadNums) { + localMaxThreadNums = kDefaultMaxThreadNums; + } + if (configThreadNums > kDefaultMaxThreadNums) { + configThreadNums = kDefaultMaxThreadNums; + } + int addNum = 0; + if (configThreadNums > kDefaultMaxThreadNums) { + addNum = configThreadNums - curThreadRunNums; + } else if (localMaxThreadNums > curThreadNums) { + addNum = localMaxThreadNums - curThreadNums; + } + AddNewThread(addNum); + if (curThreadRunNums > localMaxThreadNums) { + SubRunThread(localMaxThreadNums); + } else { + AddRunThread(localMaxThreadNums); + } + MS_LOG(DEBUG) << "configThreadNums=" << configThreadNums << ", curThreadNums=" << curThreadNums + << ", curThreadRunNums=" << curThreadRunNums << ", localMaxThreadNums=" << localMaxThreadNums; + return true; +} + +void ThreadPool::AddNewThread(int newNums) { + for (int i = curThreadNums - 1, j = 0; j < newNums; ++i, ++j) { + auto active = new std::atomic_bool{true}; + auto queue = std::make_shared(); + threadList.emplace_back([this, i, active, queue]() { + ThreadPoolTask *task = nullptr; + while (!exitRun) { + while (*active) { + if (queue->Dequeue(&task)) { + auto ret = task->first(i + 1, task->second.tvmParam, task->second.cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(i + 1, std::make_pair(false, ret))); + } + queue->taskSize--; + } + std::this_thread::yield(); + } + std::unique_lock queueLock(tMutex); + queueReady.wait(queueLock, [active, this] { return exitRun || *active; }); + } + }); + activateList.emplace_back(active); + queueList.emplace_back(queue); + } + curThreadNums += newNums; + curThreadRunNums += newNums; + MS_LOG(DEBUG) << "add " << newNums << " thread"; +} + +bool ThreadPool::SetThreadCpuBind(bool ifBind, int mode, bool master) { + if (curThreadRunNums <= 0) { + MS_LOG(ERROR) << "no threads need to be bind, totalThreadNum : " << curThreadRunNums; + return false; + } + if (threadBind == nullptr) { + threadBind = std::unique_ptr(new LiteThreadBind()); + if (threadBind == nullptr) { + MS_LOG(ERROR) << "create threadBind failed"; + return false; + } + threadBind->threadIdList.resize(kDefaultMaxThreadNums); + threadBind->InitSortedCpuId(); + } + threadBind->threadIdList.clear(); + for (auto &it : threadList) { + threadBind->threadIdList.emplace_back(it.native_handle()); + } + threadBind->bindModel = static_cast(mode); + if (!threadBind->Bind(ifBind, curThreadRunNums, master)) { + MS_LOG(ERROR) << "bind failed"; + return false; + } + return true; +} + +bool ThreadPool::AddTask(WorkFun &&worker, void *cdata, int numTask) { + if (numTask <= 0) { + numTask = curThreadRunNums; + } + TvmEnv env{}; + env.num_task = numTask; + errorInfo.clear(); + // single task, run master thread + if (curThreadRunNums <= 1) { + for (int i = 0; i < numTask; ++i) { + int ret = worker(i, &env, cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(0, std::make_pair(false, ret))); + } + } + return CheckResult(); + } + ThreadPoolTask task; + task.first = std::move(worker); + task.second.cdata = cdata; + task.second.tvmParam = &env; + return DistributeTask(&task, numTask); +} + +bool ThreadPool::DistributeTask(ThreadPoolTask *task, int numTask) { + MS_LOG(DEBUG) << "numTask = " << numTask << ", curThreadRunNums = " << curThreadRunNums; + auto taskOri = *task; + if (numTask > curThreadRunNums) { + task->first = [taskOri, numTask, this](int task_id, TvmEnv *penv, void *cdata) -> int { + for (int i = task_id; i < numTask; i += curThreadRunNums) { + int ret = taskOri.first(i, penv, cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(i + 1, std::make_pair(false, ret))); + } + } + return 0; + }; + } + bool kSuccFlag; + auto size = std::min(curThreadRunNums, numTask); + for (int i = 0; i < size - 1; ++i) { + do { + kSuccFlag = true; + if (!queueList[i]->Enqueue(task)) { + std::this_thread::yield(); + kSuccFlag = false; + } + } while (!kSuccFlag); + } + // master thread + int ret = task->first(0, task->second.tvmParam, task->second.cdata); + if (ret != 0) { + errorInfo.emplace_back(std::make_pair(0, std::make_pair(false, ret))); + } + kSuccFlag = false; + while (!kSuccFlag) { + std::this_thread::yield(); + kSuccFlag = true; + for (int i = 0; i < curThreadRunNums - 1; ++i) { + if (queueList[i]->taskSize != 0) { + kSuccFlag = false; + break; + } + } + } + MS_LOG(DEBUG) << "finish " << numTask << " task successful"; + return CheckResult(); +} + +void ThreadPool::AddRunThread(int num) { + MS_LOG(DEBUG) << "num=" << num << ", curThreadRunNums=" << curThreadRunNums; + int activeNums = num - curThreadRunNums; + if (activeNums <= 0 || activateList.size() < activeNums) { + return; + } + for (int i = curThreadRunNums - 1, j = 0; j < activeNums; ++i, ++j) { + *activateList[i] = true; + } + std::lock_guard queueLock(tMutex); + queueReady.notify_all(); + curThreadRunNums = num; +} + +void ThreadPool::SubRunThread(int num) { + MS_LOG(DEBUG) << "num=" << num << ", curThreadRunNums=" << curThreadRunNums; + int deactiveNums = curThreadRunNums - num; + if (deactiveNums <= 0) { + return; + } + for (int i = num - 1, j = 0; j < deactiveNums; ++i, ++j) { + *activateList[i] = false; + } + curThreadRunNums = num; +} + +bool ThreadPool::CheckResult() { + bool kSuccFlag = true; + for (auto result : errorInfo) { + if (result.second.first) { + MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second; + kSuccFlag = false; + } + } + return kSuccFlag; +} + +bool ThreadPool::LaunchWork(WorkFun worker, void *cdata, int numTask) { + if (!SetThreadPool()) { + return false; + } + return AddTask(std::move(worker), cdata, numTask); +} + +bool ThreadPool::BindAllThreads(bool ifBind, int mode, bool master) { + if (!SetThreadPool()) { + return false; + } + return SetThreadCpuBind(ifBind, mode, master); +} + +void ThreadPool::ConfigThreadPool(int mode, int numThreads) { + configBindMode = mode; + configThreadNums = numThreads; +} + +void ThreadPool::ConfigMaxThreadNum(unsigned int num) { localMaxThreadNums = num; } + +ThreadPool *ThreadPool::GetInstance() { + static ThreadPool instance; + return &instance; +} + +ThreadPool::~ThreadPool() { + curThreadRunNums = static_cast(threadList.size() + 1); + exitRun = true; + SubRunThread(kDefaultThreadNum); + queueReady.notify_all(); + for (auto &it : threadList) { + if (it.joinable()) { + it.join(); + } + } + for (const auto &it : activateList) { + delete it; + } +} +} // namespace predict +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/thread_pool.h b/mindspore/lite/src/runtime/thread_pool.h new file mode 100644 index 00000000000..f9a26bff65c --- /dev/null +++ b/mindspore/lite/src/runtime/thread_pool.h @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/runtime/runtime_api.h" + +namespace mindspore { +namespace predict { +#ifndef CPU_SET +const int CPU_SETSIZE = 1024; +#define __NCPUBITS (8 * sizeof(uint64_t)) +typedef struct { + uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; +} cpu_set_t; + +#define CPU_SET_LOCAL(cpu, cpusetp) ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#endif + +constexpr int kSingleThreadMaxTask = 2; +using TvmEnv = LiteParallelGroupEnv; +using WorkFun = std::function; +using TaskParam = struct Param { + void *cdata; + TvmEnv *tvmParam; +}; +using ThreadPoolTask = std::pair; +enum AffinityMode : int { BIG_CORE = 1, MID_CORE = -1, NO_BIND = 0 }; + +class LiteQueue { + public: + LiteQueue() = default; + ~LiteQueue() = default; + bool Enqueue(ThreadPoolTask *task); + bool Dequeue(ThreadPoolTask **out); + std::atomic_int taskSize = {0}; + + private: + std::atomic_int head = {0}; + std::atomic_int tail = {0}; + ThreadPoolTask *buffer[kSingleThreadMaxTask]{}; +}; + +class LiteThreadBind { + public: + LiteThreadBind() = default; + ~LiteThreadBind() = default; + void InitSortedCpuId(); + bool Bind(bool ifBind, int numThreads, bool master); + AffinityMode bindModel = MID_CORE; + std::vector threadIdList; + + private: + bool BindMasterThread(bool bindFlag, int mode); + bool BindThreads(bool bindFlag); + bool SetCPUBind(pthread_t threadId, cpu_set_t *cpuSet); + int bigCore = 0; + int midCore = 0; + std::vector sortedCpuIds{}; +}; + +class ThreadPool { + public: + ThreadPool() = default; + ~ThreadPool(); + static ThreadPool *GetInstance(); + bool LaunchWork(WorkFun worker, void *cdata, int numTask); + void ConfigThreadPool(int mode, int numThreads); + void ConfigMaxThreadNum(unsigned int num); + bool BindAllThreads(bool ifBind, int mode, bool master = true); + ThreadPool(const ThreadPool &) = delete; + ThreadPool &operator=(const ThreadPool &) = delete; + + private: + bool SetThreadPool(); + void AddNewThread(int newNums); + bool SetThreadCpuBind(bool ifBind, int mode, bool master); + bool AddTask(WorkFun &&worker, void *cdata, int numTask); + bool DistributeTask(ThreadPoolTask *task, int numTask); + void AddRunThread(int num); + void SubRunThread(int num); + bool CheckResult(); + + std::mutex poolMutex; + std::mutex tMutex; + std::condition_variable queueReady; + std::atomic_bool exitRun = {false}; + std::vector activateList{}; + int curThreadNums = 1; + int curThreadRunNums = 1; + int configThreadNums = 1; + int configBindMode = -1; + std::vector threadList{}; + std::vector> queueList{}; + std::unique_ptr threadBind{nullptr}; + std::vector>> errorInfo{}; +}; +} // namespace predict +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_RUNTIME_THREAD_POOL_H_ + diff --git a/mindspore/lite/src/runtime/workspace_pool.cc b/mindspore/lite/src/runtime/workspace_pool.cc new file mode 100644 index 00000000000..b1cd76cb1d3 --- /dev/null +++ b/mindspore/lite/src/runtime/workspace_pool.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/workspace_pool.h" +#ifdef __APPLE__ +#include +#else +#include +#endif +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace predict { +static constexpr size_t kWorkspacePageSize = 4096; +static constexpr int kTempAllocaAlignment = 64; +WorkspacePool *WorkspacePool::GetInstance() { + static WorkspacePool instance; + return &instance; +} + +void *WorkspacePool::AllocWorkSpaceMem(size_t size) { + size_t nbytes = (size + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize; + if (nbytes == 0) { + nbytes = kWorkspacePageSize; + } + std::pair alloc; + // fist alloc + if (freeList.empty()) { + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } else if (freeList.size() == 1) { // one element + alloc = *(freeList.begin()); + freeList.erase(freeList.begin()); + if (alloc.first < nbytes) { + free(alloc.second); + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } + } else { + if ((*(freeList.begin())).first >= nbytes) { + auto iter = freeList.begin(); + for (; iter != freeList.end(); ++iter) { + if ((*iter).first < size) { + alloc = *(--iter); + freeList.erase(iter); + break; + } + } + if (iter == freeList.end()) { + alloc = *(freeList.rbegin()); + freeList.erase(--freeList.end()); + } + } else { + alloc = *(freeList.begin()); + freeList.erase(freeList.begin()); + free(alloc.second); + alloc.first = nbytes; +#ifdef __APPLE__ + int err = posix_memalign(&alloc.second, kTempAllocaAlignment, nbytes); + if (err != 0) { + MS_LOGE("posix_memalign failed, error code:%d", err); + return alloc.second; + } +#else + alloc.second = memalign(kTempAllocaAlignment, nbytes); +#endif + } + } + allocList.emplace_back(alloc); + return alloc.second; +} + +void WorkspacePool::FreeWorkSpaceMem(void *ptr) { + if (ptr == nullptr) { + return; + } + std::pair alloc; + if (allocList.empty()) { + MS_LOG(ERROR) << "no mem have been alloc"; + return; + } else if (allocList.back().second == ptr) { + alloc = allocList.back(); + allocList.pop_back(); + } else { + auto iter = allocList.begin(); + for (; iter != allocList.end(); ++iter) { + if ((*iter).second == ptr) { + alloc = *iter; + allocList.erase(iter); + break; + } + } + if (iter == allocList.end()) { + MS_LOG(ERROR) << "no value ptr have been alloc"; + return; + } + } + freeList.insert(alloc); +} + +WorkspacePool::~WorkspacePool() { + for (auto &a : allocList) { + free(a.second); + } + allocList.clear(); + for (auto &f : freeList) { + free(f.second); + } + freeList.clear(); +} +} // namespace predict +} // namespace mindspore + diff --git a/mindspore/lite/src/runtime/workspace_pool.h b/mindspore/lite/src/runtime/workspace_pool.h new file mode 100644 index 00000000000..9342200b28b --- /dev/null +++ b/mindspore/lite/src/runtime/workspace_pool.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace predict { +class WorkspacePool { + public: + WorkspacePool() = default; + ~WorkspacePool(); + WorkspacePool(const WorkspacePool &) = delete; + WorkspacePool &operator=(const WorkspacePool &) = delete; + static WorkspacePool *GetInstance(); + void *AllocWorkSpaceMem(size_t size); + void FreeWorkSpaceMem(void *ptr); + + private: + std::vector> allocList{}; + std::set, std::greater>> freeList{}; +}; +} // namespace predict +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_RUNTIME_WORKSPACE_POOL_H_ + diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc new file mode 100644 index 00000000000..b207840b0ea --- /dev/null +++ b/mindspore/lite/src/scheduler.cc @@ -0,0 +1,173 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/scheduler.h" +#include +#include +#include "include/errorcode.h" +#include "src/kernel_factory.h" +#if SUPPORT_GPU +#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#endif + +namespace mindspore::lite { +int Scheduler::Schedule(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { + // 1. op ---> kernel + // 2. sub graph + // 3. kernels (kernels --> subGraph) + int ret = InitOp2Kernel(model, tensors, kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "init op to kernel failed."; + return RET_ERROR; + } + + kernel::LiteKernelUtil::TopologicalSortKernels(*kernels); + + ConstructSubgraphs(kernels); + + MS_LOG(DEBUG) << "schedule kernels success."; + return RET_OK; +} + +int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { + MS_EXCEPTION_IF_NULL(model); + MS_EXCEPTION_IF_NULL(tensors); + MS_EXCEPTION_IF_NULL(kernels); + auto meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + uint32_t kernelCount = meta_graph->nodes()->size(); + for (uint32_t i = 0; i < kernelCount; i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + std::vector inputs; + std::vector outputs; + auto inIndexes = cNode->inputIndex(); + for (size_t j = 0; j < inIndexes->size(); j++) { + inputs.emplace_back(tensors->at(size_t(inIndexes->GetAs(j)))); + } + auto outIndexes = cNode->outputIndex(); + for (size_t j = 0; j < outIndexes->size(); j++) { + outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs(j)))); + } + auto *primitive = model->GetOp(cNode->name()->str()); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Op " << cNode->name()->str() << " should exist in model, type: " + << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return RET_ERROR; + } + auto ret = primitive->InferShape(inputs, outputs); + if (0 != ret) { + MS_LOG(ERROR) << "InferShape failed, name: " << cNode->name()->str() + << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return ret; + } + + auto *kernel = this->ScheduleNode(inputs, outputs, primitive); + if (nullptr == kernel) { + MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << cNode->name()->str() + << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type()); + return RET_ERROR; + } + kernel->set_name(cNode->name()->str()); + kernels->emplace_back(kernel); + } + return RET_OK; +} + +void Scheduler::ConstructSubgraphs(std::vector *kernels) { + uint32_t kernel_count = kernels->size(); + std::vector sub_kernels; + std::vector> sub_kernels_list; + + kernel::KERNEL_ARCH prev_arch = kernels->front()->Desc().arch; + for (uint32_t i = 0; i < kernel_count; ++i) { + auto curr_kernel = kernels->at(i); + auto curr_arch = curr_kernel->Desc().arch; + if (curr_arch == prev_arch) { + sub_kernels.emplace_back(curr_kernel); + } + if ((curr_arch != prev_arch) || (i == kernel_count - 1)) { + sub_kernels_list.emplace_back(sub_kernels); + sub_kernels.clear(); + sub_kernels.emplace_back(curr_kernel); + } + prev_arch = curr_arch; + } + + std::vector subgraph_kernels; + for (auto temp_kernels : sub_kernels_list) { + kernel::KERNEL_ARCH arch = temp_kernels.front()->Desc().arch; + if (arch == kernel::KERNEL_ARCH::kCPU) { + std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels)); + } else { + auto subgraph_kernel = CreateSubKernel(temp_kernels, arch); + subgraph_kernels.emplace_back(subgraph_kernel); + } + } + kernels->clear(); + kernels->insert(kernels->begin(), subgraph_kernels.begin(), subgraph_kernels.end()); +} + +kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector &kernels, + kernel::KERNEL_ARCH arch) { + kernel::LiteKernel *sub_kernel = nullptr; +#if SUPPORT_GPU + if (arch == kernel::KERNEL_ARCH::kGPU) { + std::vector input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels); + std::vector output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); + std::vector input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels); + std::vector output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels); + sub_kernel = + new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels); + sub_kernel->Init(); + } else if (arch == kernel::KERNEL_ARCH::kNPU) { + MS_LOG(ERROR) << "NPU kernel is not supported"; + } else { + MS_LOG(ERROR) << "unsupported kernel arch: " << arch; + } +#endif + return sub_kernel; +} + +int Scheduler::MarkKernels(const std::vector &kernels) { return 0; } + +int Scheduler::MergeKernels(std::vector *kernels) { return 0; } + +kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &inputs, + const std::vector &outputs, + const lite::Primitive *primitive) { + // todo: support CPU, NPU, APU + MS_ASSERT(nullptr != primitive); + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, primitive->Type()}; + if (context->deviceCtx.type == DT_GPU) { + desc.arch = kernel::KERNEL_ARCH::kGPU; + auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc); + if (nullptr != kernel) { + kernel->set_desc(desc); + return kernel; + } + } + desc.arch = kernel::KERNEL_ARCH::kCPU; + auto *kernel = KernelFactory::GetInstance()->GetKernel(inputs, outputs, primitive, context, desc); + if (nullptr != kernel) { + kernel->set_desc(desc); + return kernel; + } + return nullptr; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h new file mode 100644 index 00000000000..6d0d98e1dfa --- /dev/null +++ b/mindspore/lite/src/scheduler.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ +#define MINDSPORE_LITE_SRC_SCHEDULER_H_ + +#include +#include "src/lite_kernel.h" +#include "include/context.h" +#include "include/model.h" + +namespace mindspore::lite { +class Scheduler { + public: + explicit Scheduler(const Context *ctx) : context(ctx) {} + int Schedule(const lite::Model *model, std::vector *tensors, + std::vector *kernels); + + protected: + kernel::LiteKernel *ScheduleNode(const std::vector &inputs, + const std::vector &outputs, const lite::Primitive *primitive); + // find schedule able kernels and save in markedKernelGroup + int MarkKernels(const std::vector &kernels); + // use SubGraphKernel to replace group in kernels + int MergeKernels(std::vector *kernels); + + private: + int InitOp2Kernel(const lite::Model *model, std::vector *tensors, + std::vector *kernels); + + // construct SubGraphKernel for each kernel-group in markedKernelGroup + void ConstructSubgraphs(std::vector *kernels); + + kernel::LiteKernel *CreateSubKernel(const std::vector &kernels, kernel::KERNEL_ARCH arch); + + protected: + std::vector> markedKernelGroup; + const Context *context = nullptr; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_SCHEDULER_H_ diff --git a/mindspore/lite/src/train/base_ref_utils.cc b/mindspore/lite/src/train/base_ref_utils.cc new file mode 100644 index 00000000000..5df16a95520 --- /dev/null +++ b/mindspore/lite/src/train/base_ref_utils.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/train/base_ref_utils.h" +#include +#include +// #include "utils/base_ref_utils.h" +#include "include/ms_tensor.h" +#include "src/ir/tensor.h" + +namespace mindspore { +std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref) { + std::vector> msTensors; + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + for (size_t i = 0; i < ref_list.size(); ++i) { + if (utils::isa(ref_list[i])) { + auto tensor_ptr = utils::cast>(ref_list[i]); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); + msTensors.emplace_back(std::shared_ptr(tensor)); + } else { + MS_LOG(EXCEPTION) << "The output is not a tensor!"; + } + } + } else if (utils::isa(base_ref)) { + auto tensor_ptr = utils::cast>(base_ref); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); + msTensors.emplace_back(std::shared_ptr(tensor)); + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + return msTensors; +} + +std::vector>> TransformVectorRefToMultiTensor( + const VectorRef &vector_ref) { + std::vector>> multiTensor; + for (size_t i = 0; i < vector_ref.size(); ++i) { + auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); + multiTensor.emplace_back(tensors); + } + return multiTensor; +} +} // namespace mindspore + diff --git a/mindspore/lite/src/train/base_ref_utils.h b/mindspore/lite/src/train/base_ref_utils.h new file mode 100644 index 00000000000..63370efeb93 --- /dev/null +++ b/mindspore/lite/src/train/base_ref_utils.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "base/base_ref.h" +#include "include/ms_tensor.h" + +#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H +#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H +namespace mindspore { +std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref); + +std::vector>> TransformVectorRefToMultiTensor( + const VectorRef &vector_ref); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H + diff --git a/mindspore/lite/src/train/import.hpp b/mindspore/lite/src/train/import.hpp new file mode 100644 index 00000000000..e8153beacb3 --- /dev/null +++ b/mindspore/lite/src/train/import.hpp @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/common/anf_importer/import_from_meta_graph.h" +namespace mindspore::lite::train { +std::shared_ptr Import(const char *model_buf, size_t size) { + MS_EXCEPTION_IF_NULL(model_buf); + flatbuffers::Verifier verify((const uint8_t *) model_buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; + return nullptr; + } + // todo hangangqiang remove when copy primitive done + auto *inner_buf = new char[size]; + memcpy(inner_buf, model_buf, size); + auto meta_graph = schema::GetMetaGraph(inner_buf); + auto model = std::make_shared(meta_graph); + auto ret = model->BuildOps(); + if (0 != ret) { + MS_LOG(ERROR) << "BuildOps failed"; + return nullptr; + } + MS_EXCEPTION_IF_NULL(meta_graph); + auto importer = new AnfImporterFromMetaGraph(model); + auto ret2 = importer->Import(); + if (0 != ret2) { + MS_LOG(ERROR) << "Import anf_graph from meta_graph failed, ret2: " << ret2; + return nullptr; + } + return model; +} +} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/lite_kernel_runtime.cc b/mindspore/lite/src/train/lite_kernel_runtime.cc new file mode 100644 index 00000000000..e1eb95d4420 --- /dev/null +++ b/mindspore/lite/src/train/lite_kernel_runtime.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/src/train/lite_kernel_runtime.h" +namespace mindspore::lite { +std::vector LiteInferKernelRuntime::GetGraphInputs(const std::vector &execution_order) { + std::vector graph_inputs; + for (const auto &cnode : execution_order) { + bool is_graph_inputs = true; + for (const auto &input : cnode->inputs()) { + if (input->isa()) { + is_graph_inputs = false; + break; + } + } + if (is_graph_inputs) { + graph_inputs.emplace_back(cnode); + } + } + return graph_inputs; +} + +void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, + const std::vector &inputs, VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(graph); + auto execution_order = graph->execution_order(); + auto graph_inputs = GetGraphInputs(execution_order); + int input_count = 0; + for (const auto &graph_input : graph_inputs) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(graph_input)); + for (auto input_tensor : liteKernel->GetInputs()) { + if (schema::NodeType_ValueNode == input_tensor->TensorType() && input_tensor->Data() != nullptr) { + continue; + } + input_tensor->SetData(inputs[input_count]->Data()); + input_count++; + } + } + + auto return_node = graph->get_return(); + for (const auto &return_input : return_node->inputs()) { + if (return_input->isa()) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(return_input)); + auto output_tensors = liteKernel->GetOutputs(); + for (auto output_tensor : output_tensors) { + tensor::TensorPtr output_tensor_ptr(output_tensor); + outputs->push_back(output_tensor_ptr); + } + } + } +} + +bool LiteInferKernelRuntime::Run(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector kernels; + auto nodes = graph->execution_order(); + for (const auto &node : nodes) { + auto liteKernel = dynamic_cast(AnfAlgo::GetKernelMod(node)); + if (liteKernel == nullptr) { + continue; + } + kernels.emplace_back(liteKernel); + } + kernel::LiteKernelUtil::TopologicalSortKernels(kernels); + Executor executor; + auto ret = executor.Run(kernels); + return 0 == ret; +} +} // namespace mindspore::lite + diff --git a/mindspore/lite/src/train/lite_kernel_runtime.h b/mindspore/lite/src/train/lite_kernel_runtime.h new file mode 100644 index 00000000000..27b4ec867bd --- /dev/null +++ b/mindspore/lite/src/train/lite_kernel_runtime.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ +#define MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include "src/runtime/allocator.h" +#include "src/executor.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/device_address.h" +#include "src/lite_kernel.h" +#include "backend/session/kernel_graph.h" +namespace mindspore::lite { +class LiteInferKernelRuntime : public device::KernelRuntime { + public: + LiteInferKernelRuntime() = default; + ~LiteInferKernelRuntime() override = default; + + bool Init() override { return true; } + + void BindInputOutput(const session::KernelGraph *graph, const std::vector &inputs, + VectorRef *outputs); + + bool Run(session::KernelGraph *graph); + + void AssignKernelAddress(session::KernelGraph *graph) {} + + protected: + std::vector GetGraphInputs(const std::vector &execution_order); + bool SyncStream() override { return true; }; + device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override { + return nullptr; + }; +}; + +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ + diff --git a/mindspore/lite/src/train/model_impl.cc b/mindspore/lite/src/train/model_impl.cc new file mode 100644 index 00000000000..30d60f77094 --- /dev/null +++ b/mindspore/lite/src/train/model_impl.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/train/model_impl.h" +#include "schema/model_generated.h" +#include "ir/func_graph.h" + +namespace mindspore::lite::train { + +const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { + auto iter = ops.find(name); + if (iter == ops.end()) { + return nullptr; + } else { + return iter->second; + } +} + +void ModelImpl::FreeMetaGraph() { delete this->meta_graph; } + +const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } + +lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { + MS_EXCEPTION_IF_NULL(srcPrim); + auto op_type = srcPrim->value_type(); + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new lite::SoftMax(const_cast(srcPrim)); + case schema::PrimitiveType_Activation: + return new lite::Activation(const_cast(srcPrim)); + case schema::PrimitiveType_Conv2D: + return new lite::Conv2D(const_cast(srcPrim)); + case schema::PrimitiveType_Reduce: + return new lite::Reduce(const_cast(srcPrim)); + case schema::PrimitiveType_Pooling: + return new lite::Pooling(const_cast(srcPrim)); + case schema::PrimitiveType_DepthwiseConv2D: + return new lite::DepthwiseConv2D(const_cast(srcPrim)); + case schema::PrimitiveType_FusedBatchNorm: + return new lite::FusedBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_CaffeBatchNorm: + return new lite::CaffeBatchNorm(const_cast(srcPrim)); + case schema::PrimitiveType_FullConnection: + return new lite::FullConnection(const_cast(srcPrim)); + case schema::PrimitiveType_Power: + return new lite::Power(const_cast(srcPrim)); + case schema::PrimitiveType_Range: + return new lite::Range(const_cast(srcPrim)); + case schema::PrimitiveType_Mul: + return new lite::Mul(const_cast(srcPrim)); + case schema::PrimitiveType_Add: + return new lite::Add(const_cast(srcPrim)); + case schema::PrimitiveType_Sub: + return new lite::Sub(const_cast(srcPrim)); + case schema::PrimitiveType_Div: + return new lite::Div(const_cast(srcPrim)); + case schema::PrimitiveType_BiasAdd: + return new lite::BiasAdd(const_cast(srcPrim)); + case schema::PrimitiveType_ExpandDims: + return new lite::ExpandDims(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMax: + return new lite::ArgMax(const_cast(srcPrim)); + case schema::PrimitiveType_ArgMin: + return new lite::ArgMin(const_cast(srcPrim)); + case schema::PrimitiveType_Cast: + return new lite::Cast(const_cast(srcPrim)); + case schema::PrimitiveType_Reshape: + return new lite::Reshape(const_cast(srcPrim)); + case schema::PrimitiveType_Scale: + return new lite::Scale(const_cast(srcPrim)); + case schema::PrimitiveType_Eltwise: + return new lite::Eltwise(const_cast(srcPrim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(srcPrim)); + case schema::PrimitiveType_Concat: + return new lite::Concat(const_cast(srcPrim)); + case schema::PrimitiveType_Fill: + return new lite::Fill(const_cast(srcPrim)); + case schema::PrimitiveType_Transpose: + return new lite::Transpose(const_cast(srcPrim)); + case schema::PrimitiveType_Slice: + return new lite::Slice(const_cast(srcPrim)); + case schema::PrimitiveType_Nchw2Nhwc: + return new lite::Nchw2Nhwc(const_cast(srcPrim)); + case schema::PrimitiveType_Nhwc2Nchw: + return new lite::Nhwc2Nchw(const_cast(srcPrim)); + default: + break; + } + return nullptr; +} + +int ModelImpl::BuildOps() { + if (this->meta_graph == nullptr) { + MS_LOG(ERROR) << "mete_graph is nullptr"; + return -1; + } + for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + auto name = cNode->name()->str(); + auto srcPrim = cNode->primitive(); + this->ops[name] = CopyPrimitive(srcPrim); + } +} +} // namespace mindspore::lite::train diff --git a/mindspore/lite/src/train/model_impl.h b/mindspore/lite/src/train/model_impl.h new file mode 100644 index 00000000000..496fed2ac3c --- /dev/null +++ b/mindspore/lite/src/train/model_impl.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ +#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H + +#include +#include +#include +#include "schema/model_generated.h" +#include "src/ops/ops.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +namespace train { +class ModelImpl : public FuncGraph { + public: + static std::shared_ptr Import(const char *model_buf, size_t size); + ModelImpl() = default; + explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} + ~ModelImpl() override = default; + const lite::Primitive *GetOp(const std::string &name) const; + const schema::MetaGraph *GetMetaGraph() const; + void FreeMetaGraph(); + int BuildOps(); + + protected: + lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); + + protected: + const schema::MetaGraph *meta_graph = nullptr; + std::map ops; +}; +} // namespace train +using ModelImpl = mindspore::lite::train::ModelImpl; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_INCLUDE_MODEL_H + diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc new file mode 100644 index 00000000000..8f114a1ea54 --- /dev/null +++ b/mindspore/lite/src/train/train_session.cc @@ -0,0 +1,232 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/src/train/train_session.h" +#include "mindspore/lite/src/kernel_factory.h" +#include "mindspore/lite/src/param_value_lite.h" +#include "common/utils.h" +#include "mindspore/lite/src/ops/ops.h" +#include "ir/anf.h" +#include "mindspore/lite/src/ir/tensor.h" +#include "abstract/abstract_value.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "src/ir/primitive_value.h" + +namespace mindspore { +namespace session { +static std::vector GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + auto shape = nodeAbstract->GetShapeTrack(); + if (!shape->isa()) { + MS_LOG(EXCEPTION) << "Not a Shape"; + return {}; + } + auto dims = dyn_cast(shape)->shape(); + return dims; + } else { + MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; + return {}; + } +} + +static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode + } else { + MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; + return schema::Format_NHWC; + } +} + +static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { + auto nodeAbstract = anfNodePtr->abstract(); + if (nodeAbstract != nullptr) { + return nodeAbstract->GetTypeTrack()->type_id(); + } else { + MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; + return TypeId::kTypeUnknown; + } +} + +int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { + auto return_node = kernel_graph->get_return(); + auto node_list = TopoSort(return_node); + for (auto &node : node_list) { + if (!node->isa()) { + continue; + } + KernelRelation kernel_relation; + auto cnode = node->cast(); + kernel_relation.node_full_name = cnode->fullname_with_scope(); + kernel_relation.cnode = cnode; + auto *out_tensor = + new tensor::Tensor(GetAnfNodeOutTypeId(cnode), GetAnfNodeOutDims(cnode), GetAnfNodeFormat(cnode), + schema::NodeType_Parameter); + kernel_relation.output_tensor.push_back(out_tensor); + tensor::Tensor *tensor_ptr = nullptr; + for (size_t index = 1; index < cnode->inputs().size(); ++index) { + if (cnode->input(index)->isa()) { + auto input_cnode = cnode->input(index)->cast(); + auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; + // todo not support multi-outputs kernel sudo as spilt + tensor_ptr = input_kernel_relation.output_tensor.front(); + } else if (cnode->input(index)->isa()) { + auto input_parameter = cnode->input(index)->cast(); + auto para = input_parameter->default_param(); + auto param_value = std::dynamic_pointer_cast(para); + auto dims = param_value->tensor_shape(); + tensor_ptr = new tensor::Tensor(param_value->tensor_type(), dims, schema::Format_NHWC, + schema::NodeType_ValueNode); // XXX TODO -- extract Format from AnfNode + if (param_value->tensor_size() != 0) { + tensor_ptr->SetData(param_value->tensor_addr()); + } + } else if (cnode->input(index)->isa()) { + auto input_valuenode = cnode->input(index)->cast(); + tensor_ptr = new tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), GetAnfNodeOutDims(input_valuenode), + schema::Format_NHWC, + schema::NodeType_Parameter); // XXX TODO -- extract Format from AnfNode + // todo(yankai) + } else { + MS_ASSERT(false); + } + kernel_relation.input_tensor.push_back(tensor_ptr); + } + kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; + } + return 0; +} + +GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + auto graph_id = graph_sum_; + auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + + BuildKernel(graph.get()); + MS_LOG(INFO) << "Assign kernel address"; + runtime_.AssignKernelAddress(graph.get()); + return graph_id; +} + +GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } + +std::shared_ptr TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto graph = NewKernelGraph(); + graph->set_return(func_graph->get_return()); + auto node_list = TopoSort(func_graph->get_return()); + std::vector cnode_order; + for (const auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cn_node = node->cast(); + cnode_order.push_back(cn_node); + } + } + graph->set_execution_order(cnode_order); + return graph; +} + +GraphId TrainSession::CompileGraph(NotNull func_graph) { + auto graph = ConstructKernelGraph(func_graph); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Set kernel info"; + SetKernelInfo(graph.get()); + + (void) BuildKernelInputAndOutputFromFuncGraph(graph); + MS_LOG(INFO) << "Build kernel"; + auto ret = BuildKernel(graph.get()); + if (0 != ret) { + MS_LOG(EXCEPTION) << "BuildKernel failed"; + } + + // return the graph id to backend + auto graph_id = graph->graph_id(); + graphs_[graph_id] = graph; + MS_LOG(INFO) << "Compile graph " << graph_id << " success"; + return graph_id; +} + +void TrainSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, + std::vector &outputs) { + auto &kernel_graph = graphs_[graph_id]; + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Bind input output address"; + runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); + // auto execution_order = kernel_graph->execution_order(); + // Todo : hangangqiang + // Reorder(&execution_order); + // kernel_graph->set_execution_order(execution_order); + MS_LOG(INFO) << "Run graph start"; + auto ret = runtime_.Run(kernel_graph.get(), (std::vector &) inputs, outputs); + if (!ret) { + MS_LOG(EXCEPTION) << "Run graph failed"; + } + MS_LOG(INFO) << "Run graph end"; +} + +void TrainSession::SetKernelInfo(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_info = std::make_shared(); + kernel_node->set_kernel_info(kernel_info); + } +} + +int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { + std::string kernel_name = iter->first; + KernelRelation anf_register = iter->second; + MS_EXCEPTION_IF_NULL(anf_register.cnode); + if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { + continue; + } + lite::Context context; + context.deviceCtx.type = lite::DeviceType::DT_CPU; + auto value_node_prim = anf_register.cnode->input(0); + MS_EXCEPTION_IF_NULL(value_node_prim); + auto prim = GetValueNode>(value_node_prim); + MS_EXCEPTION_IF_NULL(prim); + auto node_primitive = (lite::Primitive *) (prim->GetPrimitive()); + MS_EXCEPTION_IF_NULL(node_primitive); + auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); + if (0 != ret) { + MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; + return ret; + } + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, node_primitive->Type()}; + + auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, + node_primitive, &context, desc); + if (nullptr == kernel) { + MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; + return -1; + } + kernel->train(); + auto *kernel_info = anf_register.cnode->kernel_info(); + std::shared_ptr kernel_mod(kernel); + kernel_info->set_kernel_mod(kernel_mod); + } + return 0; +} +} // namespace session +} // namespace mindspore + diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h new file mode 100644 index 00000000000..d9b026d55f7 --- /dev/null +++ b/mindspore/lite/src/train/train_session.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "mindspore/lite/src/train/lite_kernel_runtime.h" +#include "backend/session/session_factory.h" +namespace mindspore { +namespace lite::tensor { +class Tensor; +} +namespace session { +struct KernelRelation { + std::string node_full_name; + std::vector input_tensor; + std::vector output_tensor; + CNodePtr cnode; +}; + +class TrainSession : public SessionBasic { + public: + TrainSession() : SessionBasic() {} + ~TrainSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kCPUDevice, device_id); + } + + GraphId CompileGraph(NotNull func_graph) override; + + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + + private: + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraph(const char *model_buf, size_t size); + std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); + int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); + void SetKernelInfo(const KernelGraph *kernel_graph); + int BuildKernel(const KernelGraph *kernel_graph); + lite::LiteInferKernelRuntime runtime_; + std::map kernel_relation_infos_; +}; +MS_REG_SESSION(kCPUDevice, TrainSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ + diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt new file mode 100644 index 00000000000..7311bfec1a3 --- /dev/null +++ b/mindspore/lite/test/CMakeLists.txt @@ -0,0 +1,298 @@ +set(TEST_DIR ${TOP_DIR}/tests/ut/cpp) +set(LITE_DIR ${TOP_DIR}/mindspore/lite) +include_directories(${TEST_DIR}) +include_directories(${LITE_DIR}/tools) +include_directories(${LITE_DIR}/lite) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake/dependency_gtest.cmake) + +### anf src +set(ANF_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_tensor.cc + ${CCSRC_DIR}/gvar/logging_level.cc + ${CCSRC_DIR}/gvar/typeid_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/base/base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/log_adapter.cc + ) +if(BUILD_CONVERTER) + set(ANF_SRC + ${ANF_SRC} + # core/abstract + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/abstract_function.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/analysis_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/param_validator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/abstract_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/dshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/abstract/utils.cc + # core/base + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/base/base_ref.cc + # core/ir + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/anf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/anf_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/graph_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../ccsrc/utils/func_graph_cloner.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/func_graph_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/primitive.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/visitor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/meta_tensor_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/named.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/scope.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/container.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/empty.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/number.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/ref.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/ir/dtype/type_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/any.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/symbolic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../core/utils/misc.cc + ## ccsrc + ${CCSRC_DIR}/debug/info.cc + ${CCSRC_DIR}/debug/trace_base.cc + ${CCSRC_DIR}/debug/trace_info.cc + ${CCSRC_DIR}/debug/label.cc + ${CCSRC_DIR}/debug/draw.cc + ${CCSRC_DIR}/pybind_api/export_flags.cc + ${CCSRC_DIR}/utils/profile.cc + ${CCSRC_DIR}/utils/context/ms_context.cc + ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc + ) +else() + set(ANF_SRC + ${ANF_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/../src/ir/meta_tensor_extends.cc + ) +endif() +### cpu kernel +file(GLOB_RECURSE KERNEL_OP_SRC + ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/opclib/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/opclib/fp32/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/opclib/int8/*.cc + ${LITE_DIR}/src/runtime/kernel/arm/opclib/quantization/*.cc + ) +if (PLATFORM_ARM64) + # assembly + file(GLOB_RECURSE TEST_ASSEMBLY_SRC ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm64/*.s + ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm64/*.S) + + set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${TEST_ASSEMBLY_SRC} + ) +endif() +if (PLATFORM_ARM32) + # assembly + set(GLOB_RECURSE TEST_ASSEMBLY_SRC + ${LITE_DIR}/src/runtime/kernel/arm/opclib/assembly/arm32/*.S) + set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${TEST_ASSEMBLY_SRC} + ) +endif() +if (ENABLE_FP16) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${LITE_DIR}/src/runtime/kernel/arm/fp16/convolution_fp16.cc + ${LITE_DIR}/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc + ) +endif () +### gpu kernel +if (SUPPORT_GPU) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${LITE_DIR}/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc + ${LITE_DIR}/src/runtime/kernel/opencl/utils.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/arithmetic.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/convolution.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/pooling2d.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc + ) +endif() +### runtime framework +file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) +set(TEST_LITE_SRC + ${ANF_SRC} + ${OPS_SRC} + ${KERNEL_OP_SRC} + ${LITE_DIR}/src/runtime/allocator.cc + ${LITE_DIR}/src/runtime/runtime_api.cc + ${LITE_DIR}/src/runtime/thread_pool.cc + ${LITE_DIR}/src/runtime/workspace_pool.cc + ${LITE_DIR}/src/ir/tensor.cc + ${LITE_DIR}/src/context.cc + ${LITE_DIR}/src/executor.cc + ${LITE_DIR}/src/kernel_factory.cc + ${LITE_DIR}/src/kernel_registry.cc + ${LITE_DIR}/src/lite_kernel.cc + ${LITE_DIR}/src/lite_session.cc + ${LITE_DIR}/src/model.cc + ${LITE_DIR}/src/model_impl.cc + ${LITE_DIR}/src/populate_parameter.cc + ${LITE_DIR}/src/scheduler.cc + ${LITE_DIR}/src/common/graph_util.cc + ${LITE_DIR}/src/common/file_utils.cc + ${LITE_DIR}/src/common/utils.cc + ${LITE_DIR}/tools/common/graph_util.cc + ${LITE_DIR}/tools/common/tensor_util.cc + ${LITE_DIR}/tools/common/node_util.cc + ${LITE_DIR}/tools/common/flag_parser.cc + ${LITE_DIR}/tools/common/storage.cc + ${LITE_DIR}/tools/benchmark/benchmark.cc + ${LITE_DIR}/test/benchmark_test.cc + ) +### gpu runtime +if (SUPPORT_GPU) + include_directories(${TOP_DIR}/third_party/OpenCL-Headers) + include_directories(${TOP_DIR}/third_party/OpenCL-CLHPP/include) + set(OPENCL_RUNTIME_SRC + ${LITE_DIR}/src/runtime/opencl/opencl_allocator.cc + ${LITE_DIR}/src/runtime/opencl/opencl_executor.cc + ${LITE_DIR}/src/runtime/opencl/opencl_runtime.cc + ${LITE_DIR}/src/runtime/opencl/opencl_wrapper.cc + ) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${OPENCL_RUNTIME_SRC} + ) +endif() +### converter +if(BUILD_CONVERTER) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${LITE_DIR}/tools/converter/optimizer.cc + ${LITE_DIR}/src/common/anf_importer/anf_importer.cc + ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc + ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc + ${LITE_DIR}/tools/converter/anf_transform.cc + ${LITE_DIR}/tools/converter/graphdef_transform.cc + ${LITE_DIR}/tools/converter/converter_flags.cc + ${LITE_DIR}/tools/converter/converter.cc + ${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc + ${LITE_DIR}/test/converter_test.cc + ${LITE_DIR}/src/gllo/common/node_pass.cc + ${LITE_DIR}/src/gllo/common/optimizer.cc + ${LITE_DIR}/src/gllo/common/pass_manager.cc + ${LITE_DIR}/src/gllo/common/pattern_engine.cc + ${LITE_DIR}/src/gllo/common/visit.cc + ${LITE_DIR}/src/gllo/common/utils.cc + ${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc + ) +endif() +### train +if (SUPPORT_TRAIN) + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + # ${SRC_DIR}/common/trans.cc + # ${SRC_DIR}/common/lite/trans_extends.cc + # ${SRC_DIR}/kernel/kernel_build_info.cc + # ${SRC_DIR}/utils/lite/base_ref_utils.cc + # ${SRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc + # ${SRC_DIR}/session/lite/session_basic_extends.cc + # ${SRC_DIR}/session/anf_runtime_algorithm.cc + # ${SRC_DIR}/session/anf_runtime_algorithm.cc + # ${SRC_DIR}/session/session_basic.cc + # ${SRC_DIR}/session/kernel_graph.cc + # ${SRC_DIR}/session/session_factory.cc + # ${SRC_DIR}/device/kernel_info.cc + # ${SRC_DIR}/device/kernel_runtime.cc + # ${SRC_DIR}/device/lite/kernel_runtime_extends.cc + ${LITE_DIR}/src/common/anf_importer/anf_importer.cc + ${LITE_DIR}/src/common/anf_importer/import_from_meta_graph.cc + ${LITE_DIR}/src/ir/primitive_value.cc + ${LITE_DIR}/src/train/lite_kernel_runtime.cc + ${LITE_DIR}/src/train/train_session.cc + ${LITE_DIR}/src/train/model_impl.cc + ) +else() + set(TEST_LITE_SRC + ${TEST_LITE_SRC} + ${LITE_DIR}/src/lite_session.cc + ) +endif() +### test src +file(GLOB_RECURSE TEST_CASE_KERNEL_SRC + ${TEST_DIR}/kernel/cpu/arm/fp32/*.cc + ${TEST_DIR}/kernel/cpu/arm/int8/*.cc +) + +set(TEST_SRC + ${TEST_LITE_SRC} + ${TEST_CASE_KERNEL_SRC} + ${TEST_DIR}/common/common_test.cc + ${TEST_DIR}/common/test_lite_main.cc + ${TEST_DIR}/kernel/cpu/arm/common/pack_tests.cc + ${TEST_DIR}/device/cpu/arm/infer_test.cc +# ${TEST_DIR}/device/cpu/arm/graph_test.cc +) + +if (SUPPORT_TRAIN) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/device/cpu/arm/train_test.cc + ) +else() + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/device/cpu/arm/infer_test.cc + ) +endif() + +if (SUPPORT_GPU) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/device/opencl/opencl_infer_tests.cc + ${TEST_DIR}/kernel/opencl/utils_cl_tests.cc + ${TEST_DIR}/kernel/opencl/arithmetic_tests.cc + ${TEST_DIR}/kernel/opencl/convolution_tests.cc + ${TEST_DIR}/kernel/opencl/depthwise_conv2d_tests.cc + ${TEST_DIR}/kernel/opencl/matmul_tests.cc + ${TEST_DIR}/kernel/opencl/max_pooling_cl_tests.cc + ${TEST_DIR}/kernel/opencl/avg_pooling_cl_tests.cc + ${TEST_DIR}/kernel/opencl/softmax_cl_tests.cc + ${TEST_DIR}/kernel/opencl/concat_tests.cc + ${TEST_DIR}/kernel/opencl/conv2d_transpose_tests.cc + ) +endif() + +if (ENABLE_FP16) + set(TEST_SRC + ${TEST_SRC} + ${TEST_DIR}/kernel/cpu/arm/fp16/convolution_fp16_tests.cc) +endif () + + +add_executable(lite-test ${TEST_SRC}) + +target_link_libraries(lite-test dl ${SECUREC_LIBRARY} ${GTEST_LIBRARY} mindspore::json) +if (BUILD_CONVERTER) + target_link_libraries(lite-test + anf_exporter_mid + tflite_parser_mid + caffe_parser_mid + node_mid + graph_pass_mid + fusion_mid + quantizer_mid + pthread + protobuf + mindspore::eigen + ) +endif() + diff --git a/mindspore/lite/test/benchmark_test.cc b/mindspore/lite/test/benchmark_test.cc new file mode 100644 index 00000000000..dfade90210d --- /dev/null +++ b/mindspore/lite/test/benchmark_test.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "tests/ut/cpp/common/common_test.h" +#include "benchmark/benchmark.h" + +namespace mindspore { +namespace lite { +class BenchmarkTest : public UT::Common { + public: + BenchmarkTest() {} +}; + +TEST_F(BenchmarkTest, TestVideo) { + const char *argv[] = {"./benchmark", "--modelPath=./models/hiai_label_and_video.ms"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(BenchmarkTest, TestOCR_02) { + const char *argv[] = {"./benchmark", "--modelPath=./models/hiai_cv_focusShootOCRMOdel_02.ms"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(BenchmarkTest, TestHebing) { + const char *argv[] = {"./benchmark", "--modelPath=./models/model_hebing_3branch.ms"}; + auto status = RunBenchmark(2, argv); + ASSERT_EQ(status, RET_OK); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/test/converter_test.cc b/mindspore/lite/test/converter_test.cc new file mode 100644 index 00000000000..af78da31600 --- /dev/null +++ b/mindspore/lite/test/converter_test.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "converter/converter.h" +#include "tests/ut/cpp/common/common_test.h" + +namespace mindspore { +namespace lite { +class ConverterTest : public UT::Common { + public: + ConverterTest() {} +}; + +TEST_F(ConverterTest, TestLenet) { + const char *argv[] = {"./converter", "--fmk=MS", "--modelFile=./models/lenet_bin.pb", + "--outputFile=./models/lenet_bin"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestVideo) { + const char *argv[] = {"./converter", "--fmk=TFLITE", "--modelFile=./models/hiai_label_and_video.tflite", + "--outputFile=./models/hiai_label_and_video"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestOCR_02) { + const char *argv[] = {"./converter", "--fmk=TFLITE", "--modelFile=./models/hiai_cv_focusShootOCRMOdel_02.tflite", + "--outputFile=./models/hiai_cv_focusShootOCRMOdel_02"}; + auto status = RunConverter(4, argv); + ASSERT_EQ(status, RET_OK); +} + +TEST_F(ConverterTest, TestHebing) { + const char *argv[] = {"./converter", "--fmk=CAFFE", "--modelFile=./models/model_hebing_3branch.caffemodel", + "--weightFile=./models/model_hebing_3branch.prototxt", + "--outputFile=./models/model_hebing_3branch"}; + auto status = RunConverter(5, argv); + ASSERT_EQ(status, RET_OK); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/benchmark/CMakeLists.txt b/mindspore/lite/tools/benchmark/CMakeLists.txt new file mode 100644 index 00000000000..0f2da58425f --- /dev/null +++ b/mindspore/lite/tools/benchmark/CMakeLists.txt @@ -0,0 +1,22 @@ +# add shared link library +set(COMMON_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc + ) + +add_executable(benchmark + ${CMAKE_CURRENT_SOURCE_DIR}/main.cc + ${CMAKE_CURRENT_SOURCE_DIR}/benchmark.cc + ${COMMON_SRC}) + +if (PLATFORM_ARM32 OR PLATFORM_ARM64) + target_link_libraries(benchmark mindspore-lite ${SECUREC_LIBRARY}) +else() + target_link_libraries(benchmark mindspore-lite ${SECUREC_LIBRARY} pthread) +endif() + +target_link_libraries(benchmark + mindspore::json +# mindspore::eigen + ) diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc new file mode 100644 index 00000000000..b648a70980c --- /dev/null +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -0,0 +1,531 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/benchmark/benchmark.h" +#define __STDC_FORMAT_MACROS +#include +#undef __STDC_FORMAT_MACROS +#include +#include +#include +#include "src/common/common.h" +#include "include/ms_tensor.h" +#include "include/context.h" + +namespace mindspore { +namespace lite { +int Benchmark::GenerateRandomData(size_t size, void *data) { + MS_ASSERT(data != nullptr); + char *castedData = static_cast(data); + for (size_t i = 0; i < size; i++) { + castedData[i] = static_cast(i); + } + return 0; +} + +int Benchmark::GenerateInputData() { + for (auto tensor : msInputs) { + MS_ASSERT(tensor != nullptr); + auto inputData = tensor->MutableData(); + if (inputData == nullptr) { + MS_LOG(ERROR) << "MallocData for inTensor failed"; + return RET_ERROR; + } + MS_ASSERT(tensor->GetData() != nullptr); + auto tensorByteSize = tensor->Size(); + auto status = GenerateRandomData(tensorByteSize, inputData); + if (status != 0) { + MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status; + return status; + } + } + return 0; +} + +int Benchmark::LoadInput() { + if (_flags->inDataPath.empty()) { + auto status = GenerateInputData(); + if (status != 0) { + MS_LOG(ERROR) << "Generate input data error " << status; + return status; + } + } else { + auto status = ReadInputFile(); + if (status != 0) { + MS_LOG(ERROR) << "ReadInputFile error, " << status; + return status; + } + } + return 0; +} + +int Benchmark::ReadInputFile() { + if (msInputs.empty()) { + return 0; + } + + if (this->_flags->inDataType == kImage) { + // int cvFlags; + // if (inTensor->Channel() == 3) { + // cvFlags = 0; // cv::IMREAD_COLOR; + // } else if (inTensor->Channel() == 1) { + // cvFlags = 1; // cv::IMREAD_GRAYSCALE; + // } else { + // MS_LOG(ERROR) << "Image mode only support imgChannel == 1 or 3, imgChannel : %lld", (long + // long)inTensor->Channel(); return RET_PARAM_INVALID; + // } + // todo fill inTensor->GetData() + } else { + for (auto i = 0; i < _flags->input_data_list.size(); i++) { + auto cur_tensor = msInputs.at(i); + MS_ASSERT(cur_tensor != nullptr); + size_t size; + char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size); + auto tensorDataSize = cur_tensor->Size(); + if (size != tensorDataSize) { + MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size; + return RET_ERROR; + } + auto inputData = cur_tensor->MutableData(); + memcpy(inputData, binBuf, tensorDataSize); + } + } + return 0; +} + +// calibData is FP32 +int Benchmark::ReadCalibData() { + const char *calibDataPath = _flags->calibDataPath.c_str(); + // read calib data + std::ifstream inFile(calibDataPath); + if (!inFile.good()) { + MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist"; + return 1; + } + + if (!inFile.is_open()) { + MS_LOG(ERROR) << "file: " << calibDataPath << " open failed"; + inFile.close(); + return 1; + } + + std::string line; + + MS_LOG(INFO) << "Start reading calibData file"; + std::string tensorName; + while (!inFile.eof()) { + getline(inFile, line); + std::stringstream stringLine1(line); + size_t dim = 0; + stringLine1 >> tensorName >> dim; + std::vector dims; + size_t shapeSize = 1; + for (size_t i = 0; i < dim; i++) { + size_t tmpDim; + stringLine1 >> tmpDim; + dims.push_back(tmpDim); + shapeSize *= tmpDim; + } + + getline(inFile, line); + std::stringstream stringLine2(line); + std::vector tensorData; + for (size_t i = 0; i < shapeSize; i++) { + float tmpData; + stringLine2 >> tmpData; + tensorData.push_back(tmpData); + } + + auto *checkTensor = new CheckTensor(dims, tensorData); + this->calibData.insert(std::make_pair(tensorName, checkTensor)); + } + inFile.close(); + MS_LOG(INFO) << "Finish reading calibData file"; + return 0; +} + +// tensorData need to be converter first +float Benchmark::CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData) { + auto iter = this->calibData.find(nodeName); + if (iter != this->calibData.end()) { + std::vector castedMSShape; + size_t shapeSize = 1; + for (int64_t dim : msShape) { + castedMSShape.push_back(size_t(dim)); + shapeSize *= dim; + } + + CheckTensor *calibTensor = iter->second; + if (calibTensor->shape != castedMSShape) { + std::ostringstream oss; + oss << "Shape of mslite output("; + for (auto dim : castedMSShape) { + oss << dim << ","; + } + oss << ") and shape source model output("; + for (auto dim : calibTensor->shape) { + oss << dim << ","; + } + oss << ") are different"; + MS_LOG(ERROR) << "%s", oss.str().c_str(); + return -1; + } + size_t errorCount = 0; + float meanError = 0; + std::cout << "Data of node " << nodeName << " : "; + for (size_t j = 0; j < shapeSize; j++) { + if (j < 50) { + std::cout << msTensorData[j] << " "; + } + + auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j)); + auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j)); + if (absoluteError > tolerance) { + // just assume that atol = rtol + meanError += absoluteError / (fabs(calibTensor->data.at(j)) + 1); + errorCount++; + } + } + std::cout << std::endl; + if (meanError > 0.0f) { + errorCount = 0; + meanError /= errorCount; + } + + if (meanError <= 0.0000001) { + std::cout << "Mean bias of node " << nodeName << " : 0%" << std::endl; + } else { + std::cout << "Mean bias of node " << nodeName << " : " << meanError * 100 << "%" << std::endl; + } + return meanError; + } else { + MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str(); + return -1; + } +} + +int Benchmark::CompareOutput() { + std::cout << "================ Comparing Output data ================" << std::endl; + float totalBias = 0; + int totalSize = 0; + bool hasError = false; + for (const auto &calibTensor : calibData) { + std::string nodeName = calibTensor.first; + auto tensors = session->GetOutputsByName(nodeName); + for (auto tensor : tensors) { + MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); + MS_ASSERT(tensor->GetData() != nullptr); + float bias = CompareData(nodeName, tensor->shape(), static_cast(tensor->MutableData())); + if (bias >= 0) { + totalBias += bias; + totalSize++; + } else { + hasError = true; + break; + } + } + } + + if (!hasError) { + float meanBias; + if (totalSize != 0) { + meanBias = totalBias / totalSize * 100; + } else { + meanBias = 0; + } + + std::cout << "Mean bias of all nodes: " << meanBias << "%" << std::endl; + std::cout << "=======================================================" << std::endl << std::endl; + + if (meanBias > this->_flags->accuracyThreshold) { + MS_LOG(ERROR) << "Mean bias of all nodes is too big: " << meanBias << "%%"; + return 1; + } else { + return 0; + } + } else { + MS_LOG(ERROR) << "Error in CompareData"; + std::cout << "=======================================================" << std::endl << std::endl; + return 1; + } +} + +int Benchmark::MarkPerformance() { + MS_LOG(INFO) << "Running warm up loops..."; + for (int i = 0; i < _flags->warmUpLoopCount; i++) { + auto status = session->RunGraph(); + if (status != 0) { + MS_LOG(ERROR) << "Inference error %d" << status; + return status; + } + } + + MS_LOG(INFO) << "Running benchmark loops..."; + uint64_t timeMin = 1000000; + uint64_t timeMax = 0; + uint64_t timeAvg = 0; + + for (int i = 0; i < _flags->loopCount; i++) { + session->BindThread(true); + auto start = GetTimeUs(); + auto status = session->RunGraph(); + if (status != 0) { + MS_LOG(ERROR) << "Inference error %d" << status; + return status; + } + + auto end = GetTimeUs(); + auto time = end - start; + timeMin = std::min(timeMin, time); + timeMax = std::max(timeMax, time); + timeAvg += time; + + session->BindThread(false); + } + if (_flags->loopCount > 0) { + timeAvg /= _flags->loopCount; + // MS_LOG(INFO) << "CSV:%s:%d:%f:%f:%f\n", _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + + // 1).c_str(), + // _flags->numThreads, timeMin / 1000.0f, timeMax / 1000.0f, timeAvg / 1000.0f); + // MS_LOG(INFO) <<"Modle = %s, numThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms", + // _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str(), _flags->numThreads, + // timeMin / 1000.0f, timeMax / 1000.0f, timeAvg / 1000.0f); + + printf("CSV:%s:%d:%f:%f:%f\n", _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str(), + _flags->numThreads, timeMin / 1000.0f, timeMax / 1000.0f, timeAvg / 1000.0f); + printf("Modle = %s, numThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n", + _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1).c_str(), _flags->numThreads, + timeMin / 1000.0f, timeMax / 1000.0f, timeAvg / 1000.0f); + } + return 0; +} + +int Benchmark::MarkAccuracy() { + MS_LOG(INFO) << "MarkAccuracy"; + for (size_t i = 0; i < msInputs.size(); i++) { + auto inData = reinterpret_cast(msInputs.at(i)->MutableData()); + std::cout << "InData" << i << ": "; + for (size_t j = 0; j < 20; j++) { + std::cout << inData[j] << " "; + } + std::cout << std::endl; + } + auto status = session->RunGraph(); + if (status != 0) { + MS_LOG(ERROR) << "Inference error %d" << status; + return status; + } + + ReadCalibData(); + if (cleanData) { + for (auto &msOutput : msOutputs) { + for (auto &outputTensor : msOutput.second) { + delete outputTensor; + } + } + msOutputs.clear(); + } + return 0; +} + +int Benchmark::RunBenchmark(const std::string &deviceType) { + auto startPrepareTime = GetTimeUs(); + // Load graph + std::string modelName = _flags->modelPath.substr(_flags->modelPath.find_last_of(DELIM_SLASH) + 1); + + MS_LOG(INFO) << "start reading model file"; + size_t size = 0; + char *graphBuf = ReadFile(_flags->modelPath.c_str(), &size); + if (graphBuf == nullptr) { + MS_LOG(ERROR) << "Load graph failed while running %s", modelName.c_str(); + return 1; + } + auto model = lite::Model::Import(graphBuf, size); + auto context = new lite::Context; + if (_flags->device == "CPU") { + context->deviceCtx.type = lite::DT_CPU; + } else { + context->deviceCtx.type = lite::DT_NPU; + } + + if (_flags->cpuBindMode == -1) { + context->cpuBindMode = MID_CPU; + } else if (_flags->cpuBindMode == 0) { + context->cpuBindMode = HIGHER_CPU; + } else { + context->cpuBindMode = NO_BIND; + } + context->threadNum = _flags->numThreads; + session = session::LiteSession::CreateSession(context); + auto ret = session->CompileGraph(model.get()); + if (ret != RET_OK) { + return ret; + } + msInputs = session->GetInputs(); + auto endPrepareTime = GetTimeUs(); +#if defined(__arm__) + MS_LOG(INFO) << "PrepareTime = %lld ms, " << (endPrepareTime - startPrepareTime) / 1000; + printf("PrepareTime = %lld ms, ", (endPrepareTime - startPrepareTime) / 1000); +#else + MS_LOG(INFO) << "PrepareTime = %ld ms, " << (endPrepareTime - startPrepareTime) / 1000; + printf("PrepareTime = %ld ms, ", (endPrepareTime - startPrepareTime) / 1000); +#endif + + // Load input + MS_LOG(INFO) << "start generate input data"; + auto status = LoadInput(); + if (status != 0) { + MS_LOG(ERROR) << "Generate input data error"; + return status; + } + if (!_flags->calibDataPath.empty()) { + status = MarkAccuracy(); + if (status != 0) { + MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status; + return status; + } + } else { + status = MarkPerformance(); + if (status != 0) { + MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status; + return status; + } + } + + if (cleanData) { + for (auto &msInput : msInputs) { + delete msInput; + } + msInputs.clear(); + for (auto &data : calibData) { + data.second->shape.clear(); + data.second->data.clear(); + delete data.second; + } + calibData.clear(); + } + + delete graphBuf; + return 0; +} + +void BenchmarkFlags::InitInputDataList() { + char *input_list = new char[this->inDataPath.length() + 1]; + snprintf(input_list, this->inDataPath.length() + 1, "%s", this->inDataPath.c_str()); + char *cur_input; + const char *split_c = ","; + cur_input = strtok(input_list, split_c); + while (cur_input) { + input_data_list.emplace_back(cur_input); + cur_input = strtok(nullptr, split_c); + } + delete[] input_list; +} + +void BenchmarkFlags::InitResizeDimsList() { + std::string content; + content = this->resizeDimsIn; + std::vector shape; + auto shapeStrs = StringSplit(content, std::string(DELIM_COLON)); + for (const auto &shapeStr : shapeStrs) { + shape.clear(); + auto dimStrs = StringSplit(shapeStr, std::string(DELIM_COMMA)); + std::cout << "Resize Dims: "; + for (const auto &dimStr : dimStrs) { + std::cout << dimStr << " "; + shape.emplace_back(static_cast(std::stoi(dimStr))); + } + std::cout << std::endl; + this->resizeDims.emplace_back(shape); + } +} + +int Benchmark::Init() { + if (this->_flags == nullptr) { + return 1; + } + MS_LOG(INFO) << "ModelPath = " << this->_flags->modelPath; + MS_LOG(INFO) << "InDataPath = " << this->_flags->inDataPath; + MS_LOG(INFO) << "InDataType = " << this->_flags->inDataTypeIn; + MS_LOG(INFO) << "LoopCount = " << this->_flags->loopCount; + MS_LOG(INFO) << "DeviceType = " << this->_flags->device; + MS_LOG(INFO) << "AccuracyThreshold = " << this->_flags->accuracyThreshold; + MS_LOG(INFO) << "WarmUpLoopCount = " << this->_flags->warmUpLoopCount; + MS_LOG(INFO) << "NumThreads = " << this->_flags->numThreads; + MS_LOG(INFO) << "calibDataPath = " << this->_flags->calibDataPath; + if (this->_flags->cpuBindMode == -1) { + MS_LOG(INFO) << "cpuBindMode = MID_CPU"; + } else if (this->_flags->cpuBindMode == 1) { + MS_LOG(INFO) << "cpuBindMode = HIGHER_CPU"; + } else { + MS_LOG(INFO) << "cpuBindMode = NO_BIND"; + } + + this->_flags->inDataType = this->_flags->inDataTypeIn == "img" ? kImage : kBinary; + + if (_flags->modelPath.empty()) { + MS_LOG(ERROR) << "modelPath is required"; + return 1; + } + _flags->InitInputDataList(); + _flags->InitResizeDimsList(); + if (!_flags->resizeDims.empty() && _flags->resizeDims.size() != _flags->input_data_list.size()) { + MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; + return 1; + } + + return 0; +} + +int RunBenchmark(int argc, const char **argv) { + BenchmarkFlags flags; + Option err = flags.ParseFlags(argc, argv); + + if (err.IsSome()) { + std::cerr << err.Get() << std::endl; + std::cerr << flags.Usage() << std::endl; + return -1; + } + + if (flags.help) { + std::cerr << flags.Usage() << std::endl; + return 0; + } + + Benchmark mBenchmark(&flags); + auto status = mBenchmark.Init(); + if (status != 0) { + MS_LOG(ERROR) << "Benchmark init Error : " << status; + return 1; + } + + if (flags.device == "NPU") { + status = mBenchmark.RunBenchmark("NPU"); + } else { + status = mBenchmark.RunBenchmark("CPU"); + } + + if (status != 0) { + MS_LOG(ERROR) << "Run Benchmark Error : " << status; + return 1; + } + + MS_LOG(INFO) << "end of benchmark"; + return 0; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h new file mode 100644 index 00000000000..787d5f1a287 --- /dev/null +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -0,0 +1,146 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINNIE_BENCHMARK_BENCHMARK_H_ +#define MINNIE_BENCHMARK_BENCHMARK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tools/common/flag_parser.h" +#include "src/common/file_utils.h" +#include "src/common/utils.h" +#include "schema/model_generated.h" +#include "include/model.h" +#include "include/lite_session.h" +#include "include/inference.h" + +namespace mindspore::lite { +enum MS_API InDataType { kImage = 0, kBinary = 1 }; + +constexpr float relativeTolerance = 0.01; +constexpr float absoluteTolerance = 0.01; + +struct MS_API CheckTensor { + CheckTensor(const std::vector &shape, const std::vector &data) { + this->shape = shape; + this->data = data; + } + std::vector shape; + std::vector data; +}; + +class MS_API BenchmarkFlags : public virtual FlagParser { + public: + BenchmarkFlags() { + // common + AddFlag(&BenchmarkFlags::modelPath, "modelPath", "Input model path", ""); + AddFlag(&BenchmarkFlags::inDataPath, "inDataPath", "Input data path, if not set, use random input", ""); + AddFlag(&BenchmarkFlags::inDataTypeIn, "inDataType", "Input data type. img | bin", "bin"); + AddFlag(&BenchmarkFlags::omModelPath, "omModelPath", "OM model path, only required when device is NPU", ""); + AddFlag(&BenchmarkFlags::device, "device", "CPU | NPU", "CPU"); + AddFlag(&BenchmarkFlags::cpuBindMode, "cpuBindMode", + "Input -1 for MID_CPU, 1 for HIGHER_CPU, 0 for NO_BIND, defalut value: 1", 1); + // MarkPerformance + AddFlag(&BenchmarkFlags::loopCount, "loopCount", "Run loop count", 10); + AddFlag(&BenchmarkFlags::numThreads, "numThreads", "Run threads number", 2); + AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3); + // MarkAccuracy + AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", ""); + AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5); + // Resize + AddFlag(&BenchmarkFlags::resizeDimsIn, "resizeDims", "Dims to resize to", ""); + } + + ~BenchmarkFlags() override = default; + + void InitInputDataList(); + + void InitResizeDimsList(); + + public: + // common + std::string modelPath; + std::string inDataPath; + std::vector input_data_list; + InDataType inDataType; + std::string inDataTypeIn; + int cpuBindMode = 1; + // MarkPerformance + int loopCount; + int numThreads; + int warmUpLoopCount; + // MarkAccuracy + std::string calibDataPath; + float accuracyThreshold; + // Resize + std::string resizeDimsIn; + std::vector> resizeDims; + + std::string omModelPath; + std::string device; +}; + +class MS_API Benchmark { + public: + explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {} + + virtual ~Benchmark() = default; + + int Init(); + int RunBenchmark(const std::string &deviceType = "NPU"); + // int RunNPUBenchmark(); + + private: + // call GenerateInputData or ReadInputFile to init inputTensors + int LoadInput(); + + // call GenerateRandomData to fill inputTensors + int GenerateInputData(); + + int GenerateRandomData(size_t size, void *data); + + int ReadInputFile(); + + int ReadCalibData(); + + int CompareOutput(); + + float CompareData(const std::string &nodeName, std::vector msShape, float *msTensorData); + + int MarkPerformance(); + + int MarkAccuracy(); + + private: + BenchmarkFlags *_flags; + session::LiteSession *session; + std::vector msInputs; + std::unordered_map> msOutputs; + std::unordered_map calibData; + bool cleanData = true; +}; + +int MS_API RunBenchmark(int argc, const char **argv); +} // namespace mindspore::lite +#endif // MINNIE_BENCHMARK_BENCHMARK_H_ + diff --git a/mindspore/lite/tools/benchmark/main.cc b/mindspore/lite/tools/benchmark/main.cc new file mode 100644 index 00000000000..10d2204783f --- /dev/null +++ b/mindspore/lite/tools/benchmark/main.cc @@ -0,0 +1,20 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/benchmark/benchmark.h" + +int main(int argc, const char **argv) { return mindspore::lite::RunBenchmark(argc, argv); } + diff --git a/mindspore/lite/tools/common/CMakeLists.txt b/mindspore/lite/tools/common/CMakeLists.txt new file mode 100755 index 00000000000..0250fa63782 --- /dev/null +++ b/mindspore/lite/tools/common/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_library(converter_common_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/node_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tensor_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/storage.cc + ) +set_target_properties(converter_common_mid PROPERTIES COMPILE_FLAGS "-Wno-unused-function") diff --git a/mindspore/lite/tools/common/converter_op_utils.h b/mindspore/lite/tools/common/converter_op_utils.h new file mode 100644 index 00000000000..20356b5ade5 --- /dev/null +++ b/mindspore/lite/tools/common/converter_op_utils.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_CONVERTER_COMMON_OP_UTILS_H_ +#define PREDICT_CONVERTER_COMMON_OP_UTILS_H_ + +#include +#include +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; } +inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) { + return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT)); +} +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_CONVERTER_COMMON_OP_UTILS_H_ + diff --git a/mindspore/lite/tools/common/flag_parser.cc b/mindspore/lite/tools/common/flag_parser.cc new file mode 100755 index 00000000000..3ea4baac9bb --- /dev/null +++ b/mindspore/lite/tools/common/flag_parser.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/flag_parser.h" + +namespace mindspore { +namespace lite { +// parse flags read from command line +Option FlagParser::ParseFlags(int argc, const char *const *argv, bool supportUnknown, + bool supportDuplicate) { + MS_ASSERT(argv != nullptr); + const int FLAG_PREFIX_LEN = 2; + // Get binary name + binName = GetFileName(argv[0]); + + std::multimap> keyValues; + for (int i = 1; i < argc; i++) { + std::string tmp = argv[i]; + Trim(&tmp); + const std::string flagItem(tmp); + + if (flagItem == "--") { + break; + } + + if (flagItem.find("--") == std::string::npos) { + continue; + } + + std::string key; + Option value = Option(None()); + + size_t pos = flagItem.find_first_of("="); + if (pos == std::string::npos && flagItem.find("--no-") != std::string::npos) { + key = flagItem.substr(FLAG_PREFIX_LEN); + } else if (pos == std::string::npos) { + key = flagItem.substr(FLAG_PREFIX_LEN); + } else { + key = flagItem.substr(FLAG_PREFIX_LEN, pos - FLAG_PREFIX_LEN); + value = Option(flagItem.substr(pos + 1)); + } + + keyValues.insert(std::pair>(key, value)); + } + + Option ret = Option(InnerParseFlags(&keyValues)); + if (ret.IsSome()) { + return Option(ret.Get()); + } + + return Option(None()); +} + +bool FlagParser::GetRealFlagName(std::string *flagName, const std::string &oriFlagName) { + MS_ASSERT(flagName != nullptr); + const int BOOL_TYPE_FLAG_PREFIX_LEN = 3; + bool opaque = false; + if (StartsWithPrefix(oriFlagName, "no-")) { + *flagName = oriFlagName.substr(BOOL_TYPE_FLAG_PREFIX_LEN); + opaque = true; + } else { + *flagName = oriFlagName; + } + return opaque; +} + +// Inner parse function +Option FlagParser::InnerParseFlags(std::multimap> *keyValues) { + MS_ASSERT(keyValues != nullptr); + for (auto it = keyValues->begin(); it != keyValues->end(); ++it) { + std::string flagName; + bool opaque = GetRealFlagName(&flagName, (*it).first); + Option flagValue = (*it).second; + + auto item = flags.find(flagName); + if (item == flags.end()) { + return Option(std::string(flagName + " is not a valid flag")); + } + FlagInfo *flag = &(item->second); + if (flag == nullptr) { + return Option("Failed: flag is nullptr"); + } + if (flag->isParsed) { + return Option("Failed: already parsed flag: " + flagName); + } + std::string tmpValue; + if (!flag->isBoolean) { + if (opaque) { + return Option(flagName + " is not a boolean type"); + } + if (flagValue.IsNone()) { + return Option("No value provided for non-boolean type: " + flagName); + } + tmpValue = flagValue.Get(); + } else { + if (flagValue.IsNone() || flagValue.Get().empty()) { + tmpValue = !opaque ? "true" : "false"; + } else if (!opaque) { + tmpValue = flagValue.Get(); + } else { + return Option(std::string("Boolean flag can not have non-empty value")); + } + } + // begin to parse value + Option ret = flag->parse(this, tmpValue); + if (ret.IsNone()) { + return Option("Failed to parse value for: " + flag->flagName); + } + flag->isParsed = true; + } + + // to check flags not given in command line but added as in constructor + for (auto &flag : flags) { + if (flag.second.isRequired && !flag.second.isParsed) { + return Option("Error, value of '" + flag.first + "' not provided"); + } + } + + return Option(None()); +} + +void Replaceall(std::string *str, const std::string &oldValue, const std::string &newValue) { + if (str == nullptr) { + // MS_LOG(ERROR)("Input str is nullptr"); + return; + } + while (true) { + std::string::size_type pos(0); + if ((pos = str->find(oldValue)) != std::string::npos) { + str->replace(pos, oldValue.length(), newValue); + } else { + break; + } + } +} + +std::string FlagParser::Usage(const Option &usgMsg) const { + // first line, brief of the usage + std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : ""; + // usage of bin name + usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n"; + // help line of help message, usageLine:message of parametors + std::string helpLine = ""; + std::string usageLine = ""; + uint32_t i = 0; + for (auto flag = flags.begin(); flag != flags.end(); flag++) { + std::string flagName = flag->second.flagName; + std::string helpInfo = flag->second.helpInfo; + // parameter line + std::string thisLine = flag->second.isBoolean ? " --[no-]" + flagName : " --" + flagName + "=VALUE"; + if (++i <= flags.size()) { + // add parameter help message of each line + thisLine += " " + helpInfo; + Replaceall(&helpInfo, "\n\r", "\n"); + usageLine += thisLine + "\n"; + } else { + // breif help message + helpLine = thisLine + " " + helpInfo + "\n"; + } + } + // total usage is brief of usage+ brief of bin + help message + brief of + // parameters + return usageString + helpLine + usageLine; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/flag_parser.h b/mindspore/lite/tools/common/flag_parser.h new file mode 100755 index 00000000000..e8e87ac699c --- /dev/null +++ b/mindspore/lite/tools/common/flag_parser.h @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_FLAG_PARSER_H_ +#define PREDICT_COMMON_FLAG_PARSER_H_ + +#include +#include +#include +#include + +#include "src/common/utils.h" +#include "tools/common/option.h" + +namespace mindspore { +namespace lite { +struct FlagInfo; + +struct Nothing {}; + +class FlagParser { + public: + FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", false); } + + virtual ~FlagParser() {} + + // only support read flags from command line + virtual Option ParseFlags(int argc, const char *const *argv, bool supportUnknown = false, + bool supportDuplicate = false); + std::string Usage(const Option &usgMsg = Option(None())) const; + + template + void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2); + template + void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); + + // non-Option type fields in class + template + void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2); + + template + void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2); + + template + void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo); + + // Option-type fields + template + void AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo); + bool help; + + protected: + template + void AddFlag(std::string Flags::*t1, const std::string &flagName, const std::string &helpInfo, const char *t2) { + AddFlag(t1, flagName, helpInfo, std::string(t2)); + } + + std::string binName; + Option usageMsg; + + private: + struct FlagInfo { + std::string flagName; + bool isRequired; + bool isBoolean; + std::string helpInfo; + bool isParsed; + std::function(FlagParser *, const std::string &)> parse; + }; + + inline void AddFlag(const FlagInfo &flag); + + // construct a temporary flag + template + void ConstructFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); + + // construct a temporary flag + template + void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag); + + Option InnerParseFlags(std::multimap> *values); + + bool GetRealFlagName(std::string *flagName, const std::string &oriFlagName); + + std::map flags; +}; + +// convert to std::string +template +Option ConvertToString(T Flags::*t, const FlagParser &baseFlag) { + const Flags *flag = dynamic_cast(&baseFlag); + if (flag != nullptr) { + return std::to_string(flag->*t); + } + + return Option(None()); +} + +// construct for a Option-type flag +template +void FlagParser::ConstructFlag(Option Flags::*t1, const std::string &flagName, const std::string &helpInfo, + FlagInfo *flag) { + if (flag == nullptr) { + // MS_LOGE("FlagInfo is nullptr"); + return; + } + flag->flagName = flagName; + flag->helpInfo = helpInfo; + + flag->isBoolean = typeid(T) == typeid(bool); + flag->isParsed = false; +} + +// construct a temporary flag +template +void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) { + if (flag == nullptr) { + // MS_LOGE("FlagInfo is nullptr"); + return; + } + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + flag->flagName = flagName; + flag->helpInfo = helpInfo; + flag->isBoolean = typeid(T) == typeid(bool); + flag->isParsed = false; +} + +inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; } + +template +void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) { + if (t == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t, flagName, helpInfo, static_cast(nullptr)); +} + +template +void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t1, flagName, helpInfo, &t2); +} + +// just for test +template +void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + AddFlag(t1, flagName, helpInfo, &t2); +} + +template +void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + + FlagInfo flagItem; + + // flagItem is as a output parameter + ConstructFlag(t1, flagName, helpInfo, flagItem); + flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + *t1 = ret.Get(); + } + } + + return Option(Nothing()); + }; + + if (t2 != nullptr) { + flagItem.isRequired = false; + *t1 = *t2; + } + + flagItem.helpInfo += + !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; + if (t2 != nullptr) { + flagItem.helpInfo += ToString(*t2).Get(); + } + flagItem.helpInfo += ")"; + + // add this flag to a std::map + AddFlag(flagItem); +} + +template +void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 *t2) { + if (t1 == nullptr) { + // MS_LOGE("t1 is nullptr"); + return; + } + + Flags *flag = dynamic_cast(this); + if (flag == nullptr) { + return; + } + + FlagInfo flagItem; + + // flagItem is as a output parameter + ConstructFlag(t1, flagName, helpInfo, &flagItem); + flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option { + Flags *flag = dynamic_cast(base); + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + flag->*t1 = ret.Get(); + } + } + + return Option(Nothing()); + }; + + if (t2 != nullptr) { + flagItem.isRequired = false; + flag->*t1 = *t2; + } else { + flagItem.isRequired = true; + } + + flagItem.helpInfo += + !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: "; + if (t2 != nullptr) { + flagItem.helpInfo += ToString(*t2).Get(); + } + flagItem.helpInfo += ")"; + + // add this flag to a std::map + AddFlag(flagItem); +} + +// option-type add flag +template +void FlagParser::AddFlag(Option Flags::*t, const std::string &flagName, const std::string &helpInfo) { + if (t == nullptr) { + // MS_LOGE("t is nullptr"); + return; + } + + Flags *flag = dynamic_cast(this); + if (flag == nullptr) { + // MS_LOGE("dynamic_cast failed"); + return; + } + + FlagInfo flagItem; + // flagItem is as a output parameter + ConstructFlag(t, flagName, helpInfo, &flagItem); + flagItem.isRequired = false; + flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option { + Flags *flag = dynamic_cast(base); + if (base != nullptr) { + Option ret = Option(GenericParseValue(value)); + if (ret.IsNone()) { + return Option(None()); + } else { + flag->*t = Option(Some(ret.Get())); + } + } + + return Option(Nothing()); + }; + + // add this flag to a std::map + AddFlag(flagItem); +} +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_FLAG_PARSER_H_ + diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc new file mode 100755 index 00000000000..1b840291748 --- /dev/null +++ b/mindspore/lite/tools/common/graph_util.cc @@ -0,0 +1,671 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/graph_util.h" +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/common/tensor_util.h" +#include "tools/common/node_util.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +OpDefCopyer GetSimpleOpCopyer() { + return [](std::unique_ptr &inCNode) -> std::unique_ptr { + std::unique_ptr newCNode(new CNodeT); + + newCNode->name = inCNode->name; + newCNode->quantType = inCNode->quantType; + newCNode->primitive = std::make_unique(); + newCNode->primitive->value.type = inCNode->primitive->value.type; + // newCNode->quantParam.clear(); + // for (size_t i = 0; i < inCNode->quantParam.size(); i++) { + // auto &quantParam = inCNode->quantParam.at(i); + // auto quantParamCopy = CopyQuantParamArrayT(quantParam); + // if (quantParamCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", inOpDef->name.c_str()); + // return nullptr; + // } + // newCNode->quantParam.emplace_back(std::move(quantParamCopy)); + // } + return std::move(newCNode); + }; +} + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) { + return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx); +} + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) { + std::vector inputIndexes; + if (inputIndexIdx == -1) { + inputIndexes = node.inputIndex; + } else { + MS_ASSERT(node.inputIndex.size() > inputIndexIdx); + inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx)); + } + std::set inputNodeIdx; + for (uint32_t inputIdx : inputIndexes) { + auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx); + inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end()); + } + std::vector ret; + ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end()); + return ret; +} + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, + const int outputIndexIdx) { + return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx); +} + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) { + std::vector outputIndexes; + if (outputIndexIdx == -1) { + outputIndexes = node.outputIndex; + } else { + MS_ASSERT(node.outputIndex.size() > outputIndexIdx); + outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx)); + } + std::set outputNodeIdx; + for (uint32_t outputIdx : outputIndexes) { + auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx); + outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end()); + } + std::vector ret; + ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end()); + return ret; +} + +std::vector GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { + std::vector preNodeIdx; + for (size_t i = 0; i < graphT.nodes.size(); i++) { + auto &oldNode = graphT.nodes.at(i); + if (oldNode == nullptr) { + continue; + } + auto outputIndexes = oldNode->outputIndex; + if (IsContain(outputIndexes, tensorIdx)) { + preNodeIdx.emplace_back(i); + } + } + return std::move(preNodeIdx); +} + +std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) { + std::vector postNodeIdx; + for (size_t i = 0; i < graphT.nodes.size(); i++) { + auto &oldNode = graphT.nodes.at(i); + if (oldNode == nullptr) { + continue; + } + auto inputIndexes = oldNode->inputIndex; + if (IsContain(inputIndexes, tensorIdx)) { + postNodeIdx.emplace_back(i); + } + } + return std::move(postNodeIdx); +} + +STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + size_t nodeIdx = 0; + for (size_t i = 0; i < graphT->nodes.size(); i++) { + auto &inNode = graphT->nodes.at(i); + MS_ASSERT(inNode != nullptr); + if (inNode->name == node->name) { + nodeIdx = i; + break; + } + } + auto inputTensorIdxes = node->inputIndex; + auto outputTensorIdxes = node->outputIndex; + if (inputTensorIdxes.empty()) { + // MS_LOG(ERROR)("Node %s should has no inputs", node->name.c_str()); + return RET_ERROR; + } + if (outputTensorIdxes.size() != 1) { + // MS_LOG(ERROR)("FakeQuantNode %s should has 1 output, in fact: %zu", node->name.c_str(), + // outputTensorIdxes.size()); + return RET_ERROR; + } + auto inDataTensorIdx = inputTensorIdxes.front(); + auto outDataTensorIdx = outputTensorIdxes.front(); + + MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); + const auto &inDataTensor = graphT->allTensors.at(inDataTensorIdx); + MS_ASSERT(inDataTensor != nullptr); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + + // find poseNode + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + } + + // todo whether need to remove weightInputTensores + // remove all node's outputTensors + RemoveTensor(graphT, outputTensorIdxes); + node->inputIndex.clear(); + node->outputIndex.clear(); + + return RET_OK; +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) { + MS_ASSERT(graph != nullptr); + /* + if (graph->subgraphs.size() <= subGraphIdx) { + //MS_LOG(ERROR)("subGraphIdx out of range: %zu", subGraphIdx); + return RET_PARAM_INVALID; + } + */ + // return IsolateOneWayNode(graph->subgraphs.at(subGraphIdx).get(), nodeIdx, removeTensor); + return IsolateOneWayNode(graph, nodeIdx, removeTensor); +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) { + MS_ASSERT(graphT != nullptr); + if (graphT->nodes.size() <= nodeIdx) { + // MS_LOG(ERROR)("nodeIdx out of range: %zu", nodeIdx); + return RET_PARAM_INVALID; + } + + CNodeT *node = graphT->nodes.at(nodeIdx).get(); + auto inputTensorIdxes = node->inputIndex; + auto outputTensorIdxes = node->outputIndex; + auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx); + if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) { + // MS_LOG(ERROR)("Only support node who has no more than one input and one output"); + return RET_ERROR; + } + if (inputTensorIdxes.empty()) { + // MS_LOG(ERROR)("Error, %zuth node has no input tensor", nodeIdx); + return RET_ERROR; + } + auto inDataTensorIdx = inputTensorIdxes.front(); + if (!outputTensorIdxes.empty()) { + auto outDataTensorIdx = outputTensorIdxes.front(); + MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx); + MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr); + auto &gOutTensorIdx = graphT->outputIndex; + for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + // find poseNode + auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0); + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == outDataTensorIdx) { + *iter = inDataTensorIdx; + break; + } + } + } + } + + if (removeTensor) { + // now all node's outputTensors are useless + // remove all node's outputTensors + auto status = RemoveTensor(graphT, outputTensorIdxes); + if (status != RET_OK) { + // MS_LOG(ERROR)("RemoveOutputTensors of node %s failed", node->name.c_str()); + return RET_ERROR; + } + } + node->inputIndex.clear(); + node->outputIndex.clear(); + return RET_OK; +} + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + bool isSubNode = false; + size_t nodeIdx = 0; + for (size_t i = 0; i < graphT->nodes.size(); i++) { + auto &inNode = graphT->nodes.at(i); + if (inNode->name == node->name) { + isSubNode = true; + nodeIdx = i; + break; + } + } + if (!isSubNode) { + // MS_LOG(ERROR)("Node %s is not in graphT %s", node->name.c_str(), graphT->name.c_str()); + return RET_PARAM_INVALID; + } else { + return IsolateOneWayNode(graphT, nodeIdx, removeTensor); + } +} + +STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTensorIdxes, bool forceDelete) { + for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) { + uint32_t deleteIdx = *iter; + if (!forceDelete) { + if (GetRefCount(graphT, deleteIdx) > 1) { + iter++; + continue; + } + } + // update graph input indexes + for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) { + if (*gInIdx > deleteIdx) { + (*gInIdx)--; + } + } + // update graph output indexes + for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) { + if (*gOutIdx > deleteIdx) { + (*gOutIdx)--; + } + } + // update nodes indexes + for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) { + // update nodes input indexes + UpdateNodeIndex((*nodeIter).get(), deleteIdx); + } + // update deleteTensorIdx + for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) { + if (*selfIt > deleteIdx) { + (*selfIt)--; + } + } + graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx); + iter = toDeleteTensorIdxes.erase(iter); + } + return RET_OK; +} + +STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) { + for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) { + if (*inIdxIt == deleteIdx) { + inIdxIt = node->inputIndex.erase(inIdxIt); + } else { + if (*inIdxIt > deleteIdx) { + (*inIdxIt)--; + } + inIdxIt++; + } + } + // update nodes output indexes + for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) { + if (*outIdxIt == deleteIdx) { + outIdxIt = node->outputIndex.erase(outIdxIt); + } else { + if (*outIdxIt > deleteIdx) { + (*outIdxIt)--; + } + outIdxIt++; + } + } + return RET_OK; +} + +STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr tensor, + InsertPlace place) { + if (nodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + graphT->allTensors.emplace_back(std::move(tensor)); + uint32_t newTensorIdx = graphT->allTensors.size() - 1; + auto node = graphT->nodes.at(nodeIdx).get(); + if (place == kBefore) { + node->inputIndex.emplace_back(newTensorIdx); + } else { + node->outputIndex.emplace_back(newTensorIdx); + } + return RET_OK; +} + +STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx, + std::unique_ptr tensor) { + if (nodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + auto node = graphT->nodes.at(nodeIdx).get(); + if (inTensorIdx >= graphT->allTensors.size()) { + // MS_LOG(ERROR)("inTensorIdx out of range: %du", nodeIdx); + return RET_PARAM_INVALID; + } + if (!IsContain(node->inputIndex, inTensorIdx)) { + // MS_LOG(ERROR)("inTensorIdx(%du) is not a inputIdx of node(%du)", inTensorIdx, nodeIdx); + return RET_PARAM_INVALID; + } + graphT->allTensors.at(inTensorIdx).swap(tensor); + return RET_OK; +} + +NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { + if (existNodeIdx >= graphT->nodes.size()) { + // MS_LOG(ERROR)("nodeIdx out of range: %du", existNodeIdx); + return graphT->nodes.end(); + } + auto nodeIter = graphT->nodes.begin() + existNodeIdx; + MS_ASSERT(nodeIter != graphT->nodes.begin()); + MS_ASSERT((*nodeIter) != nullptr); + return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode); +} + +NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) { + if (place == kBefore) { + return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer); + } else if (place == kAfter) { + return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer); + } else { + // MS_LOG(ERROR)("Invalid InsertPlace : %d", place); + return graphT->nodes.end(); + } +} + +NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, + std::unique_ptr toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { + auto &existNode = *existNodeIter; + MS_ASSERT(existNode != nullptr); + MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx); + MS_ASSERT(toAddNodeIn != nullptr); + auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx); + MS_ASSERT(graphT->allTensors.size() > preTensorIdx); + + auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx); + if (preNodeIdxes.empty()) { + auto &preTensor = graphT->allTensors.at(preTensorIdx); + MS_ASSERT(preTensor != nullptr); + auto toAddTensor = CopyTensorDefT(preTensor); + if (toAddTensor == nullptr) { + MS_LOG(ERROR) << "Copy TensorT failed"; + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + preTensor->refCount = 0; + preTensor->data.clear(); + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + MS_LOG(ERROR) << "copy toAddNodeIn failed"; + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(toAddTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(preTensorIdx); + for (auto iter = graphT->inputIndex.begin(); iter != graphT->inputIndex.end(); iter++) { + if (*iter == preTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } else { + std::vector> toAddNodes; + int i = 0; + for (size_t preNodeIdx : preNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > preNodeIdx); + auto &preNode = graphT->nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + auto &preTensor = graphT->allTensors.at(preTensorIdx); + MS_ASSERT(preTensor != nullptr); + auto toAddTensor = CopyTensorDefT(preTensor); + if (toAddTensor == nullptr) { + *errorCode = RET_NULL_PTR; + // MS_LOG(ERROR)("Copy TensorT failed"); + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++); + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(preTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) { + if (*iter == preTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + toAddNodes.emplace_back(std::move(toAddNode)); + } + for (auto &toAddNode : toAddNodes) { + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } + } + *errorCode = RET_OK; + return existNodeIter; +} + +NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, + std::unique_ptr toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) { + auto &existNode = *existNodeIter; + MS_ASSERT(existNode != nullptr); + MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx); + MS_ASSERT(toAddNodeIn != nullptr); + auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx); + MS_ASSERT(graphT->allTensors.size() > postTensorIdx); + + auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx); + if (postNodeIdxes.empty()) { + auto &postTensor = graphT->allTensors.at(postTensorIdx); + MS_ASSERT(postTensor != nullptr); + auto toAddTensor = CopyTensorDefT(postTensor); + if (toAddTensor == nullptr) { + // MS_LOG(ERROR)("Copy TensorT failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(postTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) { + if (*iter == postTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } else { + std::vector> toAddNodes; + int i = 0; + for (size_t postNodeIdx : postNodeIdxes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + auto &postTensor = graphT->allTensors.at(postTensorIdx); + MS_ASSERT(postTensor != nullptr); + auto toAddTensor = CopyTensorDefT(postTensor); + if (toAddTensor == nullptr) { + // MS_LOG(ERROR)("Copy TensorT failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + graphT->allTensors.emplace_back(std::move(toAddTensor)); + size_t toAddTensorIdx = graphT->allTensors.size() - 1; + auto toAddNode = opDefCopyer(toAddNodeIn); + if (toAddNode == nullptr) { + // MS_LOG(ERROR)("copy toAddNodeIn failed"); + *errorCode = RET_NULL_PTR; + return graphT->nodes.end(); + } + toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++); + toAddNode->inputIndex.clear(); + toAddNode->inputIndex.push_back(postTensorIdx); + toAddNode->outputIndex.clear(); + toAddNode->outputIndex.push_back(toAddTensorIdx); + MS_ASSERT(IsContain(postNode->inputIndex, postTensorIdx)); + for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) { + if (*iter == postTensorIdx) { + *iter = toAddTensorIdx; + break; + } + } + toAddNodes.emplace_back(std::move(toAddNode)); + } + for (auto &toAddNode : toAddNodes) { + existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode)); + existNodeIter++; + } + } + *errorCode = RET_OK; + return existNodeIter; +} + +STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) { + if (modelFile.size() > fileType.size()) { + if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) { + return RET_OK; + } else { + return RET_ERROR; + } + } else { + return RET_ERROR; + } +} + +std::string GetModelName(const std::string &modelFile) { + std::string modelName = modelFile; + modelName = modelName.substr(modelName.find_last_of('/') + 1); + modelName = modelName.substr(0, modelName.find_last_of('.')); + + srand((unsigned)time(NULL)); + modelName = modelName + std::to_string(rand()); + + return modelName; +} + +OpGraphT *OpGraphT::Build(const schema::MetaGraphT *subGraphDef) { + if (subGraphDef == nullptr) { + // MS_LOG(ERROR)("subGraphDef is nullptr"); + return nullptr; + } + auto graph = std::unique_ptr(new OpGraphT()); + if (graph == nullptr) { + // MS_LOG(ERROR)("malloc opgraph failed"); + return nullptr; + } + + auto &opDefs = subGraphDef->nodes; + + for (auto &opDef : opDefs) { + auto ret = graph->AddEdge(opDef.get(), &opDefs); + if (ret != RET_OK) { + // MS_LOG(ERROR)("%s add edge failed. ret:%d", opDef->name.c_str(), ret); + return nullptr; + } + } + + return graph.release(); +} + +int OpGraphT::AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs) { + MS_ASSERT(srcNodeDef != nullptr); + MS_ASSERT(nodeDefs != nullptr); + NODE_ID srcId = std::string(srcNodeDef->name); + // for single op condition + AddNode(srcId); + for (auto index : srcNodeDef->outputIndex) { + for (auto &dstNodeDef : *nodeDefs) { + bool find = false; + auto inputIndex = dstNodeDef->inputIndex; + if (std::any_of(inputIndex.begin(), inputIndex.end(), [&index](int i) { return i == index; })) { + find = true; + } + + if (!find) { + continue; + } + NODE_ID dstId = std::string(dstNodeDef->name.c_str()); + auto ret = AddEdge(srcId, dstId); + if (ret != RET_OK) { + return ret; + } + } + } + return RET_OK; +} + +int OpGraphT::AddEdge(NODE_ID srcId, NODE_ID dstId) { + auto srcNode = AddNode(srcId); + if (srcNode == nullptr) { + // MS_LOG(ERROR)("add srcNode failed"); + return RET_ERROR; + } + srcNode->AddOutEdge(dstId); + auto dstNode = AddNode(dstId); + if (dstNode == nullptr) { + // MS_LOG(ERROR)("add dstNode failed"); + return RET_ERROR; + } + dstNode->AddInEdge(srcId); + return RET_OK; +} + +OpGraphT::~OpGraphT() { + for (auto iter : nodes) { + delete iter.second; + } + nodes.clear(); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h new file mode 100644 index 00000000000..818c53502bb --- /dev/null +++ b/mindspore/lite/tools/common/graph_util.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_GRAPH_UTIL_H +#define MINDSPORE_PREDICT_GRAPH_UTIL_H + +#include +#include +#include +#include +#include +#include + +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "src/common/graph_util.h" + +namespace mindspore { +namespace lite { +using STATUS = int; +enum InsertPlace { kBefore, kAfter }; + +using NodeIter = std::vector>::iterator; + +using OpDefCopyer = std::function(std::unique_ptr &)>; + +OpDefCopyer GetSimpleOpCopyer(); + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1); + +std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, + int inputIndexIdx = -1); + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int outputIndexIdx = -1); + +std::vector GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, + int outputIndexIdx = -1); + +std::vector GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); + +std::vector GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx); + +STATUS IsolateNode(schema::MetaGraphT *subGraph, schema::CNodeT *node); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor = true); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t subGraphIdx, size_t nodeIdx, bool removeTensor = true); + +STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor = true); + +STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx); + +STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector toDeleteTensorIdxes, bool forceDelete = false); + +STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr tensor, + InsertPlace place = kBefore); + +STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx, + std::unique_ptr tensor); + +NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex, + std::unique_ptr toAddNode, STATUS *errorCode, + OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); + +NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, + OpDefCopyer opDefCopyer = GetSimpleOpCopyer()); + +NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); + +NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx, + std::unique_ptr toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); + +STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); +std::string GetModelName(const std::string &modelFile); + +class OpGraphT : public OpGraph { + public: + OpGraphT() {} + ~OpGraphT(); + static OpGraphT *Build(const schema::MetaGraphT *subGraphDef); + + private: + int AddEdge(NODE_ID srcId, NODE_ID dstId); + int AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs); +}; + +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_GRAPH_UTIL_H + diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc new file mode 100644 index 00000000000..2f3312a8cb7 --- /dev/null +++ b/mindspore/lite/tools/common/node_util.cc @@ -0,0 +1,178 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/node_util.h" +#include +#include +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(node != nullptr); + // set quantParam to preNode + for (size_t i = 0; i < node->inputIndex.size(); i++) { + auto preNodeIdexes = GetInputNodeIdx(*graphT, *(node.get()), i); + for (auto preNodeIdx : preNodeIdexes) { + MS_ASSERT(graphT->nodes.size() > preNodeIdx); + auto &preNode = graphT->nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + // if preNode is not init, it maybe not a quantNode, so skip + // if (preNode->inputIndex.size() + preNode->outputIndex.size() != preNode->quantParam.size()) { + // continue; + // } + auto preNodeOutputIndexes = preNode->outputIndex; + int32_t currentNodeIndexInPre = -1; + for (auto index : preNodeOutputIndexes) { + currentNodeIndexInPre++; + if (index == node->inputIndex.at(i)) { + break; + } + } + MS_ASSERT(currentNodeIndexInPre != -1); + MS_ASSERT(node->quantParam.size() > i); + MS_ASSERT(node->quantParam.at(i) != nullptr); + // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(i)); + // if (quantParamArrayCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); + // return RET_ERROR; + // } + // preNode->quantParam.at(preNode->inputIndex.size() + currentNodeIndexInPre) = + // std::move(CopyQuantParamArrayT(quantParamArrayCopy)); + } + } + + // set quantParam to postNode + for (size_t i = 0; i < node->outputIndex.size(); i++) { + auto postNodeIdexes = GetOutputNodeIdx(*graphT, *(node.get()), i); + for (auto postNodeIdx : postNodeIdexes) { + MS_ASSERT(graphT->nodes.size() > postNodeIdx); + auto &postNode = graphT->nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + // if postNode is not init, it maybe not a quantNode, so skip + // if (postNode->inputIndex.size() + postNode->outputIndex.size() != postNode->quantParam.size()) { + // continue; + // } + auto postNodeInputIndexes = postNode->inputIndex; + int32_t currentNodeIndexInPost = -1; + for (auto index : postNodeInputIndexes) { + currentNodeIndexInPost++; + if (index == node->outputIndex.at(i)) { + break; + } + } + MS_ASSERT(currentNodeIndexInPost != -1); + MS_ASSERT(node->quantParam.size() > node->inputIndex.size() + i); + MS_ASSERT(node->quantParam.at(node->inputIndex.size() + i) != nullptr); + // auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(node->inputIndex.size() + i)); + // if (quantParamArrayCopy == nullptr) { + // //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str()); + // return RET_ERROR; + // } + // postNode->quantParam.at(currentNodeIndexInPost) = std::move(CopyQuantParamArrayT(quantParamArrayCopy)); + } + } + return RET_OK; +} + +static const std::vector nhwcOpList = { + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize, + schema::PrimitiveType_FusedBatchNorm}; + +static const std::vector fp32FullOpList = { + schema::PrimitiveType_Concat, schema::PrimitiveType_Add, + schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32 + +static const std::vector uint8NeedNhwcOpList = {}; + +static const std::vector uint8OpList = { + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation}; + +std::vector Getfp32FullOpList() { return fp32FullOpList; } + +std::vector GetNhwcOpList() { return nhwcOpList; } + +std::vector GetUint8NhwcOpList() { return uint8NeedNhwcOpList; } + +std::vector GetUint8OpList() { return uint8OpList; } + +STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector &src_dims, + mindspore::lite::Format dst_format, std::vector *dst_dims) { + if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) { + // MS_LOG(ERROR)("Convert format , src size %lu <3 or src format is equal to dst format,not need convert", + // src_dims.size()); + *dst_dims = src_dims; + return RET_PARAM_INVALID; + } + + std::vector nchw_dim; + switch (src_format) { + case Format_NCHW: + nchw_dim = src_dims; + break; + case Format_NHWC: + if (src_dims.size() == DIM_DEFAULT_SIZE) { + nchw_dim.push_back(src_dims[NHWC_N]); + nchw_dim.push_back(src_dims[NHWC_C]); + nchw_dim.push_back(src_dims[NHWC_H]); + nchw_dim.push_back(src_dims[NHWC_W]); + } else { + nchw_dim.push_back(src_dims[HWC_C]); + nchw_dim.push_back(src_dims[HWC_H]); + nchw_dim.push_back(src_dims[HWC_W]); + } + break; + default: + // MS_LOG(ERROR)("Not support src format: %d", src_format); + return RET_ERROR; + } + + if (nchw_dim.size() == 0) { + // MS_LOG(ERROR)("Param nchw_dim is empty!"); + return RET_ERROR; + } + + switch (dst_format) { + case Format_NCHW: + *dst_dims = nchw_dim; + break; + case Format_NHWC: + if (src_dims.size() == DIM_DEFAULT_SIZE) { + dst_dims->push_back(nchw_dim[NCHW_N]); + dst_dims->push_back(nchw_dim[NCHW_H]); + dst_dims->push_back(nchw_dim[NCHW_W]); + dst_dims->push_back(nchw_dim[NCHW_C]); + } + break; + default: + // MS_LOG(ERROR)("Not support dst format: %d", dst_format); + return RET_ERROR; + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h new file mode 100644 index 00000000000..43e9aef5587 --- /dev/null +++ b/mindspore/lite/tools/common/node_util.h @@ -0,0 +1,373 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_NODE_UTIL_H +#define MINDSPORE_PREDICT_NODE_UTIL_H + +#include +#include +#include "schema/inner/model_generated.h" +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +using STATUS = int; +STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); + +std::vector GetNhwcOpList(); + +std::vector Getfp32FullOpList(); + +std::vector GetUint8NhwcOpList(); + +std::vector GetUint8OpList(); + +class NodeUtils { + public: + static STATUS ConvertDims(schema::Format src_format, const std::vector &src_dims, schema::Format dst_format, + std::vector *dst_dims); + + static void SliceData(std::vector &input, int64_t chunk_size, std::vector &output, int64_t begin, + int64_t out_dim, int64_t stride); + + static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, + std::vector &begin, std::vector &output_dims, + schema::TensorT *output, std::vector &stride); +}; + +// todo check this +enum kTransFilterType { + kKCHW2HWCK, + kKCHW2KHWC, + kCKHW2KHWC, + kCKHW2HWCK, + kKCHW2HWKC, + kCKHW2HWKC, + kHWCK2KCHW, + kHWCK2CKHW, + kHWKC2KCHW, + kHWKC2CKHW, + kNHWC2KCHW, + kNHWC2CKHW, + kNHWC2HWCK, + kKHWC2HWCK, + kCHWK2HWCK, + kKHWC2CHWK, + kCHWK2KHWC +}; + +static STATUS GetFilterDim(std::vector &oriDims, kTransFilterType type, int32_t &filterK, int32_t &filterC, + int32_t &filterH, int32_t &filterW) { + MS_ASSERT(oriDims.size() == 4); + if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC) { + filterK = oriDims.at(KCHW_K); + filterC = oriDims.at(KCHW_C); + filterH = oriDims.at(KCHW_H); + filterW = oriDims.at(KCHW_W); + } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { + filterC = oriDims.at(CKHW_C); + filterK = oriDims.at(CKHW_K); + filterH = oriDims.at(CKHW_H); + filterW = oriDims.at(CKHW_W); + } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { + filterH = oriDims.at(HWCK_H); + filterW = oriDims.at(HWCK_W); + filterC = oriDims.at(HWCK_C); + filterK = oriDims.at(HWCK_K); + } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { + filterH = oriDims.at(HWKC_H); + filterW = oriDims.at(HWKC_W); + filterK = oriDims.at(HWKC_K); + filterC = oriDims.at(HWKC_C); + } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { + filterK = oriDims.at(NHWC_N); + filterH = oriDims.at(NHWC_H); + filterW = oriDims.at(NHWC_W); + filterC = oriDims.at(NHWC_C); + } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { + filterC = oriDims.at(CHWK_C); + filterH = oriDims.at(CHWK_H); + filterW = oriDims.at(CHWK_W); + filterK = oriDims.at(CHWK_K); + } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { + filterK = oriDims.at(KHWC_K); + filterH = oriDims.at(KHWC_H); + filterW = oriDims.at(KHWC_W); + filterC = oriDims.at(KHWC_C); + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} + +static STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) { + tensor->dims = {filterH, filterW, filterC, filterK}; + } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) { + tensor->dims = {filterH, filterW, filterK, filterC}; + } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) { + tensor->dims = {filterK, filterC, filterH, filterW}; + } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW) { + tensor->dims = {filterC, filterK, filterH, filterW}; + } else if (type == kKHWC2CHWK) { + tensor->dims = {filterC, filterH, filterW, filterK}; + } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { + tensor->dims = {filterK, filterH, filterW, filterC}; + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} + +template +static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + int count = filterH * filterW * filterC * filterK; + if (count <= 0) { + MS_LOG(ERROR) << "Dim size invalid"; + return RET_ERROR; + } + std::unique_ptr buf(new (std::nothrow) T[count]); + if (buf == nullptr) { + MS_LOG(ERROR) << "new buf failed"; + return RET_ERROR; + } + + void *originWeightDate = tensor->data.data(); + T *weightData = static_cast(originWeightDate); + + if (weightData == nullptr) { + MS_LOG(ERROR) << "weightData is nullptr"; + return RET_ERROR; + } + T *p1Buff = nullptr; + T *p2Buff = nullptr; + switch (type) { + case kCHWK2HWCK: + case kCHWK2KHWC: { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); + if (type == kCHWK2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kCHWK2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKHWC2HWCK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKCHW2HWCK: + case kKCHW2KHWC: + case kKCHW2HWKC: { + for (int k = 0; k < filterK; ++k) { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + if (type == kKCHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kCKHW2HWCK: + case kCKHW2KHWC: + case kCKHW2HWKC: { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + if (type == kCKHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kHWCK2KCHW: + case kHWCK2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + if (type == kHWCK2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kHWKC2KCHW: + case kHWKC2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kHWKC2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kNHWC2HWCK: + case kNHWC2KCHW: + case kNHWC2CKHW: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kNHWC2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kNHWC2CKHW) { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } break; + case kKHWC2CHWK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + } + + auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +template +static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) { + MS_ASSERT(tensor != nullptr); + std::vector oriDims = tensor->dims; + if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) { + MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); + return RET_ERROR; + } + + int32_t filterH; + int32_t filterW; + int32_t filterC; + int32_t filterK; + auto status = GetFilterDim(oriDims, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetFilterDim failed: " << status; + return status; + } + status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetFilterDim failed: " << status; + return status; + } + status = TransFilterData(tensor, type, filterK, filterC, filterH, filterW); + if (status != RET_OK) { + MS_LOG(ERROR) << "TransFilterData failed: " << status; + return status; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_PREDICT_NODE_UTIL_H + diff --git a/mindspore/lite/tools/common/option.h b/mindspore/lite/tools/common/option.h new file mode 100644 index 00000000000..8b323b73363 --- /dev/null +++ b/mindspore/lite/tools/common/option.h @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_OPTION_H_ +#define PREDICT_COMMON_OPTION_H_ + +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +template +struct InnerSome { + explicit InnerSome(const T &t) : _t(std::move(t)) {} + + T _t; +}; + +template +InnerSome::type> Some(T &&t) { + return InnerSome::type>(std::forward(t)); +} + +struct None {}; + +template +class Option { + public: + Option() : state(NONE) {} + + explicit Option(const T &t) : data(t), state(SOME) {} + + explicit Option(T &&t) : data(std::move(t)), state(SOME) {} + + explicit Option(const InnerSome &some) : data(some._t), state(SOME) {} + + explicit Option(const None &none) : state(NONE) {} + + Option(const Option &that) : state(that.state) { + if (that.IsSome()) { + new (&data) T(that.data); + } + } + + virtual ~Option() {} + + bool IsNone() const { return state == NONE; } + + bool IsSome() const { return state == SOME; } + + const T &Get() const & { + MS_ASSERT(IsSome()); + return data; + } + + T &Get() & { + MS_ASSERT(IsSome()); + return data; + } + + T &&Get() && { + MS_ASSERT(IsSome()); + return std::move(data); + } + + const T &&Get() const && { + MS_ASSERT(IsSome()); + return std::move(data); + } + + // oprerator override + Option &operator=(const Option &that) { + if (&that != this) { + if (IsSome()) { + data.~T(); + } + state = that.state; + if (that.IsSome()) { + new (&data) T(that.data); + } + } + + return *this; + } + + bool operator==(const Option &that) const { + return (IsNone() && that.IsNone()) || (IsSome() && that.IsSome() && data == that.data); + } + + bool operator!=(const Option &that) const { return !(*this == that); } + + bool operator==(const T &that) const { return IsSome() && data == that; } + + bool operator!=(const T &that) const { return !(*this == that); } + + private: + enum State { NONE = 0, SOME = 1 }; + + T data; + State state; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_OPTION_H_ + diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc new file mode 100755 index 00000000000..dac50715025 --- /dev/null +++ b/mindspore/lite/tools/common/storage.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/storage.h" +#include "flatbuffers/flatbuffers.h" +#include "utils/log_adapter.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) { + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, &graph); + builder.Finish(offset); + int size = builder.GetSize(); + auto content = builder.GetBufferPointer(); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + + std::ofstream output(outputPath + ".ms", std::ofstream::binary); + if (!output.is_open()) { + MS_LOG(ERROR) << "ofstream open failed"; + return RET_ERROR; + } + + output.write((const char *)content, size); + output.close(); + return RET_OK; +} + +schema::MetaGraphT *Storage::Load(const std::string &inputPath) { + size_t size; + auto buf = ReadFile(inputPath.c_str(), &size); + if (buf == nullptr) { + // MS_LOG(ERROR)("the file buffer is nullptr"); + return nullptr; + } + + flatbuffers::Verifier verify((const uint8_t *)buf, size); + // if (false == VerifyGraphDefBuffer(verify)) { + // //MS_LOG(ERROR)("the buffer is invalid and fail to create graph"); + // return nullptr; + // } + + auto graphDefT = schema::UnPackMetaGraph(buf); + return graphDefT.release(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/storage.h b/mindspore/lite/tools/common/storage.h new file mode 100644 index 00000000000..c1cdfa27adb --- /dev/null +++ b/mindspore/lite/tools/common/storage.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_COMMON_STORAGE_H_ +#define PREDICT_COMMON_STORAGE_H_ + +#include +#include +#include "include/errorcode.h" +#include "flatbuffers/flatbuffers.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +class Storage { + public: + int Save(const schema::MetaGraphT &graph, const std::string &outputPath); + + schema::MetaGraphT *Load(const std::string &inputPath); +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_COMMON_STORAGE_H_ + diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc new file mode 100644 index 00000000000..e41d5efd2c9 --- /dev/null +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/common/utils.h" +#include "tools/common/tensor_util.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +std::unique_ptr CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray) { + MS_ASSERT(srcQuantParamArray != nullptr); + auto dstQuantParamArrayT = std::unique_ptr(new (std::nothrow) QuantParamT()); + if (dstQuantParamArrayT == nullptr) { + // MS_LOG(ERROR)("new dstQuantParamArrayT failed"); + return nullptr; + } + /* + for (size_t i = 0; i < srcQuantParamArray->param.size(); i++) { + auto &srcQuantParam = srcQuantParamArray->param.at(i); + MS_ASSERT(srcQuantParam != nullptr); + std::unique_ptr dstQuantParam(new (std::nothrow) QuantParamT()); + if (dstQuantParam == nullptr) { + //MS_LOG(ERROR)("new dstQuantParam failed"); + dstQuantParamArrayT.release(); + return nullptr; + } + dstQuantParam->scale = srcQuantParam->scale; + dstQuantParam->zeroPoint = srcQuantParam->zeroPoint; + dstQuantParam->min = srcQuantParam->min; + dstQuantParam->max = srcQuantParam->max; + dstQuantParam->narrowRange = srcQuantParam->narrowRange; + dstQuantParam->numBits = srcQuantParam->numBits; + dstQuantParamArrayT->param.emplace_back(std::move(dstQuantParam)); + } + */ + return std::move(dstQuantParamArrayT); +} + +std::unique_ptr GetInTensorQuantParamArray(const MetaGraphT &graphT, size_t tensorIdx) { + auto preNodeIdxes = GetLinkedPreIdx(graphT, tensorIdx); + MS_ASSERT(preNodeIdxes.size() <= 1); + if (preNodeIdxes.empty()) { + // MS_LOGD("the %zuth tensor has no preNode", tensorIdx); + return nullptr; + } + auto preNodeIdx = preNodeIdxes.front(); + MS_ASSERT(preNodeIdx < graphT.nodes.size()); + auto &preNode = graphT.nodes.at(preNodeIdx); + MS_ASSERT(preNode != nullptr); + MS_ASSERT(preNode->inputIndex.size() + preNode->outputIndex.size() == preNode->quantParam.size()); + /* + for (size_t i = 0; i < preNode->outputIndex.size(); i++) { + if (preNode->outputIndex.at(i) == tensorIdx) { + auto &quantPArray = preNode->quantParam.at(preNode->inputIndex.size() + i); + MS_ASSERT(quantPArray->param.size() == 1); // only support prelayer + MS_ASSERT(quantPArray->param.front() != nullptr); + if (quantPArray->param.front()->min == FLT_MAX) { + //MS_LOGD("the %zuth tensor's preNode's relative quantParam has not be inited", tensorIdx); + return nullptr; + } else { + return std::move(CopyQuantParamArrayT(quantPArray)); + } + } + } + */ + MS_ASSERT(false); + return nullptr; +} + +std::unique_ptr GetOutTensorQuantParamArray(const MetaGraphT &graphT, size_t tensorIdx) { + auto postNodeIdxes = GetLinkedPostIdx(graphT, tensorIdx); + if (postNodeIdxes.empty()) { + // MS_LOGD("the %zuth tensor has no postNode", tensorIdx); + return nullptr; + } + // find one postNode which can give valid quantParamArray + for (auto postNodeIdx : postNodeIdxes) { + MS_ASSERT(postNodeIdx < graphT.nodes.size()); + auto &postNode = graphT.nodes.at(postNodeIdx); + MS_ASSERT(postNode != nullptr); + MS_ASSERT(postNode->inputIndex.size() + postNode->outputIndex.size() == postNode->quantParam.size()); + /* + for (size_t i = 0; i < postNode->inputIndex.size(); i++) { + if (postNode->inputIndex.at(i) == tensorIdx) { + auto &quantPArray = postNode->quantParam.at(i); + MS_ASSERT(quantPArray->param.size() == 1); // only support prelayer + MS_ASSERT(quantPArray->param.front() != nullptr); + // check if postNode has valid quantParam + if (quantPArray->param.front()->min == FLT_MAX) { + continue; + } + MS_ASSERT(graphT.allTensors.size() > postNode->inputIndex.at(i)); + auto &tensor = graphT.allTensors.at(postNode->inputIndex.at(i)); + MS_ASSERT(tensor != nullptr); + if (tensor->refCount == schema::NodeType_ValueNode) { + continue; + } + // find valid quantParam return + auto paramArray = CopyQuantParamArrayT(quantPArray); + if (paramArray == nullptr) { + //MS_LOG(ERROR)("CopyQuantParamArrayT return nullptr"); + return nullptr; + } + return std::move(paramArray); + } + }*/ + } + return nullptr; +} + +size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); } + +size_t GetElementSize(const TypeId &dataType) { + switch (dataType) { + case kNumberTypeUInt8: + return sizeof(uint8_t); + case kNumberTypeInt32: + return sizeof(int32_t); + case kNumberTypeFloat: + return sizeof(float); + case kNumberTypeInt16: + return sizeof(int16_t); + case kNumberTypeInt8: + return sizeof(int8_t); + case kNumberTypeUInt32: + return sizeof(uint32_t); + default: + return sizeof(float); + } +} + +size_t GetShapeSize(const TensorT &tensor) { + auto shape = tensor.dims; + size_t shapeSize = 1; + for (auto dim : shape) { + shapeSize *= dim; + } + return shapeSize; +} + +std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTensor) { + auto newTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newTensor == nullptr) { + // MS_LOG(ERROR)("new TensorT failed"); + return nullptr; + } + newTensor->dims = oldTensor->dims; + newTensor->format = oldTensor->format; + newTensor->dataType = oldTensor->dataType; + newTensor->refCount = oldTensor->refCount; + newTensor->nodeType = oldTensor->nodeType; + newTensor->data = oldTensor->data; + return std::move(newTensor); +} + +size_t GetRefCount(MetaGraphT *graphT, uint32_t tensorIdx) { + MS_ASSERT(graphT != nullptr); + MS_ASSERT(graphT->allTensors.size() > tensorIdx); + size_t refCount = 0; + for (auto &node : graphT->nodes) { + MS_ASSERT(node != nullptr); + if (IsContain(node->inputIndex, tensorIdx)) { + refCount++; + } + } + return refCount; +} +size_t GetShapeSize(const std::vector &shape) { + size_t shapeSize = 1; + for (auto dim : shape) { + shapeSize *= dim; + } + return shapeSize; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h new file mode 100644 index 00000000000..93cb3520a1a --- /dev/null +++ b/mindspore/lite/tools/common/tensor_util.h @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TENSOR_UTIL_H +#define MINDSPORE_PREDICT_TENSOR_UTIL_H + +#include +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite { +using schema::TensorT; +using schema::MetaGraphT; +using schema::CNodeT; +using schema::QuantParamT; +using schema::Format; +using schema::FusedBatchNormT; +using schema::Format_NCHW; +using schema::Format_NHWC; +using STATUS = int; +size_t GetElementSize(const TensorT &tensor); + +size_t GetElementSize(const TypeId &dataType); + +size_t GetShapeSize(const TensorT &tensor); + +size_t GetShapeSize(const std::vector &shape); + +std::unique_ptr CopyTensorDefT(const std::unique_ptr &); + +size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx); + +std::unique_ptr \ + CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray); + +std::unique_ptr GetInTensorQuantParamArray(const schema::MetaGraphT &graphT, size_t tensorIdx); + +std::unique_ptr GetOutTensorQuantParamArray(const schema::MetaGraphT &graphT, size_t tensorIdx); + +using MSGraphDefTPtr = std::shared_ptr; + +enum TensorType { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 }; + +class TensorCache { + public: + TensorCache() {} + + ~TensorCache() { tensors.clear(); } + + int AddTensor(const std::string &name, TensorT *tensor, int TensorType) { + index++; + if (TensorType == CONST || TensorType == TF_CONST || TensorType == GRAPH_INPUT) { + tensor->refCount = 1; + tensor->nodeType = schema::NodeType_ValueNode; + } else { + tensor->nodeType = schema::NodeType_Parameter; + } + tensors.push_back(tensor); + + if (TensorType == GRAPH_INPUT) { + graphInputs.push_back(index); + } + + if (TensorType == GRAPH_INPUT || TensorType == OP_OUTPUT || TensorType == TF_CONST) { + UpdateTensorIndex(name, index); + } + return index; + } + + // find the name index + int FindTensor(const std::string &name) { + auto iter = tensorIndex.find(name); + if (iter != tensorIndex.end()) { + return iter->second; + } + return -1; + } + + void UpdateTensorIndex(const std::string &name, int index) { + auto iter = tensorIndex.find(name); + if (iter != tensorIndex.end()) { + tensorIndex[name] = index; + } else { + tensorIndex.insert(make_pair(name, index)); + } + } + + // return allTensors + const std::vector &GetCachedTensor() const { return tensors; } + + const std::vector &GetGraphInputs() const { return graphInputs; } + + private: + std::vector tensors; + std::unordered_map tensorIndex; + std::vector graphInputs; + int index = -1; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TENSOR_UTIL_H + diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt new file mode 100644 index 00000000000..856e730b87f --- /dev/null +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -0,0 +1,109 @@ +set(ANF_SRC + ${ANF_SRC} + # core/abstract + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_function.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/analysis_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/param_validator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/abstract_value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/dshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/abstract/utils.cc + # core/base + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/base/base_ref.cc + # core/ir + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/anf_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/meta_func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/func_graph.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/graph_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../ccsrc/utils/func_graph_cloner.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/func_graph_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/primitive.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/visitor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/meta_tensor_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/named.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/scope.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/value_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/container.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/empty.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/number.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/ref.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/type.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/ir/dtype/type_extends.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/any.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/symbolic.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/misc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../core/utils/flags.cc + ## ccsrc + ${CCSRC_DIR}/debug/info.cc + ${CCSRC_DIR}/debug/trace_base.cc + ${CCSRC_DIR}/debug/trace_info.cc + ${CCSRC_DIR}/debug/label.cc + ${CCSRC_DIR}/debug/draw.cc + ${CCSRC_DIR}/pybind_api/export_flags.cc + ${CCSRC_DIR}/utils/profile.cc + ${CCSRC_DIR}/utils/context/ms_context.cc + ${CCSRC_DIR}/frontend/parallel/costmodel_context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc + ) + +file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc) + +file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/converter_flags.cc + ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc + ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc + ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/anf_importer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_meta_graphT.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_importer/import_from_protobuf.cc + ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.pb.cc + + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pattern_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc + ) + +add_subdirectory(parser/caffe) +add_subdirectory(parser/tflite) +add_subdirectory(optimizer) +add_subdirectory(quantizer) + +add_executable(converter_lite + main.cc + ${ANF_SRC} + ${CONVERTER_SRC} + ${OPS_SRC} + ) +target_link_libraries(converter_lite PRIVATE + tflite_parser_mid + caffe_parser_mid + anf_exporter_mid + node_mid + graph_pass_mid + fusion_mid + protobuf + quantizer_mid + pthread + mindspore-lite + ${SECUREC_LIBRARY} + mindspore::json + mindspore::eigen + ) + diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc new file mode 100644 index 00000000000..8684ddac7e8 --- /dev/null +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/anf_transform.h" +#include +#include +#include "utils/log_adapter.h" +#include "src/gllo/fusion/conv_biasadd_fusion.h" + + +using std::string; +namespace mindspore { +namespace lite { +AnfTransform::AnfTransform() = default; + +AnfTransform::~AnfTransform() = default; + +void AnfTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } + +FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { + // return old_graph; + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(old_graph); + return new_graph; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h new file mode 100644 index 00000000000..3b393a15bcd --- /dev/null +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ANF_TRANSFORM_H +#define MS_ANF_TRANSFORM_H + +#include "schema/inner/model_generated.h" +#include "tools/common/storage.h" +#include "tools/converter/converter_flags.h" +#include "ir/anf.h" + + +namespace mindspore { +namespace lite { +class AnfTransform { + public: + AnfTransform(); + virtual ~AnfTransform(); + FuncGraphPtr Transform(const FuncGraphPtr &old_graph); + void SetGraphDef(schema::MetaGraphT *dstDef); + inline schema::MetaGraphT *GetOutput() { return graphDefT; } + + protected: + schema::MetaGraphT *graphDefT = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc new file mode 100644 index 00000000000..c31fb88fd37 --- /dev/null +++ b/mindspore/lite/tools/converter/converter.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/converter.h" +#include +#include +#include +#include "tools/converter/converter_flags.h" +#include "src/common/common.h" +#include "src/common/file_utils.h" +#include "ir/func_graph.h" + +#include "utils/log_adapter.h" +#include "tools/common/storage.h" +#include "parser/caffe/caffe_converter.h" +#include "parser/tflite/tflite_converter.h" +#include "src/common/anf_exporter/anf_exporter.h" +#include "src/common/anf_importer/import_from_protobuf.h" +#include "tools/converter/parser/onnx/onnx.pb.h" +#include "tools/converter/quantizer/weight_quantizer.h" +#include "tools/converter/quantizer/post_training.h" + +namespace mindspore { +namespace lite { +using FmkType = converter::FmkType; +Converter::Converter() { + this->transform = new GraphDefTransform; + this->anfTransform = new AnfTransform; +} + +Converter::~Converter() { + if (nullptr != modelParser) { + delete modelParser; + } + if (nullptr != modelImporter) { + delete modelImporter; + } + if (nullptr != transform) { + delete transform; + } + if (nullptr != anfTransform) { + delete anfTransform; + } +} + +class MindsporeImporter : public Converter { + public: + MindsporeImporter(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) { + modelImporter = new AnfImporterFromProtobuf(onnx_model, std::move(func_graph)); + } + + ~MindsporeImporter() override = default; +}; + +MetaGraphT *Converter::Convert(const converter::Flags *flag) { + // parse the model and weight file to generate inference data structure + FuncGraphPtr graph = nullptr; + if (flag->fmk == converter::FmkType_MS) { + MS_ASSERT(nullptr != modelImporter); + modelImporter->Import(); + graph = modelImporter->GetResult(); + } else { + MS_ASSERT(nullptr != modelParser); + const std::string modelFile = flag->modelFile; + const std::string weightFile = flag->weightFile; + auto meta_graph = modelParser->Parse(modelFile, weightFile); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; + return nullptr; + } + // todo hangangqiang + graph = ModelParser::Fb2Anf(meta_graph); + } + if (graph == nullptr) { + MS_LOG(ERROR) << "Parser/Import model return nullptr"; + return nullptr; + } + +// auto newGraph = anfTransform->Transform(graph); + /* + CreateQuantizer(graph, flag); + if (mQuantizer != nullptr) { + auto status = mQuantizer->DoQuantize(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Quant failed " << status; + return nullptr; + } + } + */ + // anf -- fb + auto meta_graph = Export(graph); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Export to meta_graph return nullptr"; + return nullptr; + } + + // transform + transform->SetGraphDef(meta_graph); + auto status = transform->Transform(*flag); + if (status != 0) { + MS_LOG(ERROR) << "FBTransform model failed " << status; + return nullptr; + } + return meta_graph; +} +void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { + auto type = flags->quantType; + switch (type) { + case mindspore::schema::QuantType_AwareTrainning: { + // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); + break; + } + case mindspore::schema::QuantType_WeightQuant: { + MS_LOG(INFO) << "create WeightQuantizer!"; + mQuantizer.reset( + new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum)); + break; + } + case mindspore::schema::QuantType_PostTraining: { + MS_LOG(INFO) << "create PostTrainningQuantizer!"; + mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); + break; + } + case mindspore::schema::QuantType_QUANT_NONE: + MS_LOG(INFO) << "Not do quantization for model!"; + break; + default: + MS_LOG(INFO) << "will support quntizer type " << flags->quantTypeIn.c_str() << " in the future!"; + break; + } +} +int RunConverter(int argc, const char **argv) { + auto flags = new converter::Flags; + auto status = flags->Init(argc, argv); + if (status != 0) { + MS_LOG(ERROR) << "converter::Flags Init failed: " << status; + return 1; + } + // Load graph + std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1); + MS_LOG(INFO) << "start reading model file"; + + MetaGraphT *fb_graph = nullptr; + switch (flags->fmk) { + case FmkType::FmkType_MS: { + auto graph = std::make_shared(); + auto onnx_graph = AnfImporterFromProtobuf::ReadOnnxFromBinary(flags->modelFile); + MindsporeImporter mindsporeImporter(onnx_graph, graph); + fb_graph = mindsporeImporter.Convert(flags); + break; + } + case FmkType::FmkType_CAFFE: { + CaffeConverter caffeConverter; + fb_graph = caffeConverter.Convert(flags); + } break; + case FmkType::FmkType_TFLITE: { + TfliteConverter tfLiteConverter; + fb_graph = tfLiteConverter.Convert(flags); + } break; + default: { + MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk; + return 1; + } + } + if (fb_graph == nullptr) { + MS_LOG(ERROR) << "Convert model return nullptr"; + return 1; + } + + // save graph to file + Storage storage; + status = storage.Save(*fb_graph, flags->outputFile); + if (status != 0) { + MS_LOG(ERROR) << "Save graph failed"; + return 1; + } + MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!"; + + return 0; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h new file mode 100644 index 00000000000..54e3560a876 --- /dev/null +++ b/mindspore/lite/tools/converter/converter.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_CONVERTER_H +#define MS_CONVERTER_H + +#include +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/graphdef_transform.h" +#include "tools/converter/model_parser.h" +#include "src/common/anf_importer/anf_importer.h" +#include "tools/converter/converter_flags.h" +#include "tools/converter/anf_transform.h" +#include "tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +class Converter { + public: + Converter(); + virtual ~Converter(); + virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); + void CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags); + + protected: + ModelParser *modelParser = nullptr; + AnfImporter *modelImporter = nullptr; + GraphDefTransform *transform = nullptr; + AnfTransform *anfTransform = nullptr; + std::unique_ptr mQuantizer = nullptr; +}; + +int RunConverter(int argc, const char **argv); +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc new file mode 100644 index 00000000000..3fd7a9fc45c --- /dev/null +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/converter_flags.h" + + +namespace mindspore { +namespace lite { +namespace converter { +Flags::Flags() { + AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | CAFFE | ONNX | MS | TFLITE", ""); + AddFlag(&Flags::modelFile, "modelFile", + "Input model file path. TF: *.pb | CAFFE: *.prototxt | ONNX: *.onnx | MS: *.ms", ""); + AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); + AddFlag(&Flags::weightFile, "weightFile", + "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); + AddFlag(&Flags::inferenceType, "inferenceType", + "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); + AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", ""); + AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); + AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); + AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); + AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0"); + AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); + AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); +} + +int Flags::Init(int argc, const char **argv) { + Option err = this->ParseFlags(argc, argv); + + if (err.IsSome()) { + MS_LOG(ERROR) << err.Get(); + std::cerr << this->Usage() << std::endl; + return 1; + } + + if (this->help) { + std::cerr << this->Usage() << std::endl; + return 0; + } + if (this->modelFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary"; + return 1; + } + if (this->outputFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary"; + return 1; + } + + if (this->outputFile.rfind('/') == this->outputFile.length() - 1) { + MS_LOG(ERROR) << "INPUT ILLEGAL: outputFile must be a valid file path"; + return 1; + } + + if (this->fmkIn.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: fmk is necessary"; + return 1; + } + if (this->inputInferenceTypeIn == "FLOAT") { + this->inputInferenceType = 0; + } else if (this->inputInferenceTypeIn == "UINT8") { + this->inputInferenceType = 1; + } else { + MS_LOG(ERROR) << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); + return 1; + } + if (this->fmkIn == "TF") { + this->fmk = FmkType_TF; + } else if (this->fmkIn == "CAFFE") { + this->fmk = FmkType_CAFFE; + } else if (this->fmkIn == "ONNX") { + this->fmk = FmkType_ONNX; + } else if (this->fmkIn == "MS") { + this->fmk = FmkType_MS; + } else if (this->fmkIn == "TFLITE") { + this->fmk = FmkType_TFLITE; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: fmk must be TF|CAFFE|ONNX|MS"; + return 1; + } + + if (this->fmk != FmkType_CAFFE && !weightFile.empty()) { + MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile is not a valid flag"; + return 1; + } + if (this->quantTypeIn == "AwareTrainning") { + this->quantType = QuantType_AwareTrainning; + } else if (this->quantTypeIn == "WeightQuant") { + this->quantType = QuantType_WeightQuant; + } else if (this->quantTypeIn == "PostTraining") { + this->quantType = QuantType_PostTraining; + } else if (this->quantTypeIn.empty()) { + this->quantType = QuantType_QUANT_NONE; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining"; + return 1; + } + + // auto status = ValidateAwareQuantizerCLI(); + // if (status != RET_OK) { + // MS_PRINT_ERROR("Parse aware quantization command line failed: %d", status); + // return status; + // } + // status = ValidateWeighQuantCLI(); + // if (status != RET_OK) { + // MS_PRINT_ERROR("ValidateWeighQuantCLI failed: %d", status); + // return status; + // } + return 0; +} + +// bool Flags::ValidateString(const string pattern, const string input) { +// std::regex repPattern(pattern, std::regex_constants::extended); +// std::match_results regResult; +// return regex_match(input, regResult, repPattern); +//} + +// int Flags::ValidateAwareQuantizerCLI() { +// // check input inference type +// if (this->inputInferenceType == DataType_DT_FLOAT) { +// if (this->mean.empty()) { +// MS_PRINT_ERROR("mean value shound not be null!") +// return RET_PARAM_INVALID; +// } +// if (this->stdDev.empty()) { +// MS_PRINT_ERROR("standard deviation value shound not be null!") +// return RET_PARAM_INVALID; +// } +// const std::string pattern = "^[+-]?([0-9]*\.?[0-9]+|[0-9]+\.?[0-9]*)([eE][+-]?[0-9]+)?$"; +// if (!ValidateString(pattern, this->mean)) { +// MS_PRINT_ERROR("invalid input mean values: %s", this->mean.c_str()); +// return RET_PARAM_INVALID; +// } +// if (!ValidateString(pattern, this->stdDev)) { +// MS_PRINT_ERROR("invalid input standard deviation value: %s", this->stdDev.c_str()); +// return RET_PARAM_INVALID; +// } +// } else { +// if (!this->mean.empty()) { +// MS_PRINT_INFO("useless mean value: %s", this->mean.c_str()); +// } +// if (!this->stdDev.empty()) { +// MS_PRINT_INFO("useless stdDev value: %s", this->stdDev.c_str()); +// } +// } +// return RET_OK; +//} + +// int Flags::ValidateWeighQuantCLI() { +// if (!this->quantSize.empty()) { +// if (!ValidateString("^[0-9]*$", this->quantSize)) { +// MS_PRINT_ERROR("invalid input quantSize: %s, only support positive integer type!", this->quantSize.c_str()); +// return RET_PARAM_INVALID; +// } +// } +// return RET_OK; +//} +} // namespace converter +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h new file mode 100644 index 00000000000..b97d777ae13 --- /dev/null +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CONVERTER_FLAGS_H +#define CONVERTER_FLAGS_H + +#include +#include "tools/common/flag_parser.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +using mindspore::schema::QuantType; +using mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_AwareTrainning; +using mindspore::schema::QuantType_WeightQuant; +using mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_PostTraining; +namespace converter { +enum FmkType { + FmkType_TF = 0, + FmkType_CAFFE = 1, + FmkType_ONNX = 2, + FmkType_MS = 3, + FmkType_TFLITE = 4 +}; + +class Flags : public virtual mindspore::lite::FlagParser { + public: + Flags(); + + ~Flags() override = default; + + int Init(int argc, const char **argv); + + private: + bool ValidateString(std::string pattern, std::string input); + + // int ValidateAwareQuantizerCLI(); + // + // int ValidateWeighQuantCLI(); + + public: + std::string modelFile; + std::string outputFile; + std::string fmkIn; + FmkType fmk; + std::string weightFile; + std::string inputArrays; + std::string outputArrays; + std::string inputShapes; + // used for quantization + std::string quantTypeIn; + QuantType quantType; + std::string inferenceType; + // used for parse aware trainning + std::string inputInferenceTypeIn; + // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT; + int inputInferenceType = 0; + std::string stdDev; + std::string mean; + // used for post-trainning-weight + std::string quantSize; + std::string bitNum; + std::string configFile; + bool formatTrans = true; + std::string convWeightQuantChannelThreshold; +}; +} // namespace converter +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc new file mode 100644 index 00000000000..cc5bf749f1d --- /dev/null +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -0,0 +1,183 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/graphdef_transform.h" +#include +#include +#include "schema/model_generated.h" +#include "utils/log_adapter.h" +#include "src/common/op_utils.h" +#include "tools/converter/converter_flags.h" +#include "tools/converter/optimizer/fusion/conv_bn_fusion_pass.h" +#include "tools/converter/optimizer/fusion/conv_scale_fusion_pass.h" +#include "tools/converter/optimizer/fusion/conv_relu_fusion_pass.h" +#include "tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h" +#include "tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h" +// #include "tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h" +#include "tools/converter/optimizer/fusion/format_trans_fusion_pass.h" +// #include "tools/converter/optimizer/fusion/quant_cast_fusion_pass.h" +// #include "tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h" +// +// #include "tools/converter/optimizer/const_fold/add_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/cast_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/mul_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/range_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/reshape_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/shape_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/slice_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/stack_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/sub_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/tile_const_fold_pass.h" +// #include "tools/converter/optimizer/const_fold/transpose_const_fold_pass.h" +// +#include "tools/converter/optimizer/node/weight_format_pass.h" +#include "tools/converter/optimizer/graph/format_trans_pass.h" +#include "tools/converter/optimizer/graph/isolated_node_remove_pass.h" +#include "tools/converter/optimizer/graph/unused_node_remove_pass.h" +#include "tools/converter/optimizer/graph/topological_sort_pass.h" + +#include "tools/converter/converter.h" + +using std::string; +namespace mindspore { +namespace lite { +GraphDefTransform::GraphDefTransform() = default; + +GraphDefTransform::~GraphDefTransform() = default; + +void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } + +int GraphDefTransform::Transform(const converter::Flags &ctx) { + STATUS status; + // // constant folding + // { + // Optimizer topologicalSortOptimizer; + // topologicalSortOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + // status = topologicalSortOptimizer.Run(graphDefT); + // if (status != RET_OK) { + // MS_LOG(ERROR)<<"Run topologicalSortOptimizer graphPasses Failed"; + // return status; + // } + // Optimizer constFoldOptimizer; + // constFoldOptimizer.AddPass(new (std::nothrow) AddConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) CastConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ConcatV2ConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ExpandDimsConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) MulConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) RangeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ReshapeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) RsqrtConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) ShapeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) SliceConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) StackConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) StridedSliceConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) SubConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) TileConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) TransposeConstFoldPass()); + // constFoldOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + // status = constFoldOptimizer.Run(graphDefT); + // if (status != RET_OK && status != RET_NO_CHANGE) { + // MS_LOG(ERROR) << "Run constFoldOptimizer graphPasses Failed"; + // return status; + // } + // } + + // fusion + { + Optimizer fusionOptimizer; + fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass()); + fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + status = fusionOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed"; + return status; + } + } + + // weight format trans + if (ctx.formatTrans) { + Optimizer weightFormatOptimizer; + auto weightFormatPass = new (std::nothrow) WeightFormatPass(); + if (weightFormatPass == nullptr) { + MS_LOG(ERROR) << "new weightFormatPass failed"; + return RET_ERROR; + } + // weightFormatPass->SetQuantType(ctx.quantType); + weightFormatPass->SetFmkType(ctx.fmk); + weightFormatOptimizer.AddPass(weightFormatPass); + status = weightFormatOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run weightFormatOptimizer graphPasses Failed"; + return status; + } + } + + // format transform + if (ctx.formatTrans) { + Optimizer formatTransOptimizer; + auto formatTransPass = new (std::nothrow) FormatTransPass(); + if (formatTransPass == nullptr) { + MS_LOG(ERROR) << "new formatTransPass failed"; + return RET_ERROR; + } + // formatTransPass->SetQuantType(ctx.quantType); + formatTransPass->SetFmk(ctx.fmk); + formatTransOptimizer.AddPass(formatTransPass); + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + // if (ctx.quantType == QuantType_AwareTrainning) { + // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); + // } + status = formatTransOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; + return status; + } + } + + { + Optimizer unusedOpRemoveOptimizer; + unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); + status = unusedOpRemoveOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + return status; + } + } + // topological sorting + { + Optimizer topologicalOptimizer; + topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + status = topologicalOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h new file mode 100644 index 00000000000..aed9feba5a1 --- /dev/null +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_GRAPHDEF_TRANSFORM_H +#define MS_GRAPHDEF_TRANSFORM_H + +#include "tools/converter/optimizer.h" +// #include "quantizer/quantizer.h" +#include "schema/inner/model_generated.h" +#include "tools/common/storage.h" +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +/* + * transform GraphDef by fusion optimizer and quantizer + * */ + +class GraphDefTransform { + public: + GraphDefTransform(); + virtual ~GraphDefTransform(); + virtual int Transform(const converter::Flags &ctx); + void SetGraphDef(schema::MetaGraphT *dstDef); + inline schema::MetaGraphT *GetOutput() { return graphDefT; } + void CreateQuantizer(const converter::Flags *flags); + + protected: + schema::MetaGraphT *graphDefT = nullptr; + Optimizer *optimizer = nullptr; + + // std::unique_ptr mQuantizer; +}; +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/main.cc b/mindspore/lite/tools/converter/main.cc new file mode 100644 index 00000000000..6923ed75c16 --- /dev/null +++ b/mindspore/lite/tools/converter/main.cc @@ -0,0 +1,20 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/converter.h" + +int main(int argc, const char **argv) { return mindspore::lite::RunConverter(argc, argv); } + diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h new file mode 100644 index 00000000000..f9014fbc4c8 --- /dev/null +++ b/mindspore/lite/tools/converter/model_parser.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_MODEL_PARSER_H +#define MS_MODEL_PARSER_H +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "src/common/anf_importer/import_from_meta_graphT.h" +#include "ir/anf.h" +#include "include/errorcode.h" + +namespace mindspore::lite { +using namespace schema; +class ModelParser { + public: + ModelParser() {} + + virtual ~ModelParser() {} + + virtual FuncGraphPtr ParseToAnf(const std::string &modelFile, const std::string &weightFile) { + auto *meta_graph = Parse(modelFile, weightFile); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; + return nullptr; + } + return Fb2Anf(Parse(modelFile, weightFile)); + } + virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0; + + public: + static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { + MS_EXCEPTION_IF_NULL(meta_graph); + auto func_graph = std::make_shared(); + auto importer = new AnfImporterFromMetaGraphT(meta_graph, func_graph); + auto ret = importer->Import(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << ret; + return nullptr; + } + return func_graph; + } +}; +} // namespace mindspore::lite + +#endif + + diff --git a/mindspore/lite/tools/converter/optimizer.cc b/mindspore/lite/tools/converter/optimizer.cc new file mode 100644 index 00000000000..d043138931c --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +Optimizer::~Optimizer() { + for (auto pass : graphPasses) { + if (pass != nullptr) { + delete (pass); + } + } + + for (auto pass : nodePasses) { + if (pass != nullptr) { + delete (pass); + } + } +} + +void Optimizer::AddPass(GraphPass *graphPass) { + if (graphPass != nullptr) { + this->graphPasses.emplace_back(graphPass); + } +} + +void Optimizer::AddPass(NodePass *nodePass) { + if (nodePass != nullptr) { + this->nodePasses.emplace_back(nodePass); + } +} + +STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) { + STATUS status; + bool ifNotChanged = true; + // each node should go through all node pass not each node pass go through all node + for (auto &opDef : graphDefT->nodes) { + for (auto pass : this->nodePasses) { + status = pass->Run(new GraphNode(graphDefT, opDef.get())); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run NodePass failed"; + return status; + } else { + if (status == RET_OK) { + ifNotChanged = false; + } + } + } + } + + for (auto pass : this->graphPasses) { + status = pass->Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run GraphPass failed"; + return status; + } else { + if (status == RET_OK) { + ifNotChanged = false; + } + } + } + return ifNotChanged ? RET_NO_CHANGE : RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer.h b/mindspore/lite/tools/converter/optimizer.h new file mode 100644 index 00000000000..346e8c60161 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_OPTIMIZER_H +#define MS_OPTIMIZER_H +#include +#include "schema/inner/model_generated.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +using namespace schema; +template +class Pass { + public: + Pass() = default; + virtual ~Pass() = default; + virtual STATUS Run(T *t) = 0; +}; + +class GraphPass : public Pass { + public: + GraphPass() = default; + + ~GraphPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override = 0; + + // protected: + // GraphDefT *graphDefT = nullptr; +}; + +struct GraphNode { + GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {} + ~GraphNode() = default; + schema::MetaGraphT *subGraph = nullptr; + schema::CNodeT *opDef = nullptr; +}; + +class NodePass : public Pass { + public: + NodePass() = default; + + ~NodePass() override = default; + + STATUS Run(GraphNode *graphNode) override = 0; + + // protected: + // GraphNode *graphNode = nullptr; +}; + +class Optimizer { + public: + Optimizer() = default; + + virtual ~Optimizer(); + + void AddPass(GraphPass *graphPass); + + void AddPass(NodePass *nodePass); + + STATUS Run(schema::MetaGraphT *graphDefT); + + private: + std::vector graphPasses; + std::vector nodePasses; +}; +} // namespace lite +} // namespace mindspore + +#endif + + diff --git a/mindspore/lite/tools/converter/optimizer/CMakeLists.txt b/mindspore/lite/tools/converter/optimizer/CMakeLists.txt new file mode 100755 index 00000000000..898d0607388 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/CMakeLists.txt @@ -0,0 +1,6 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(fusion) +#add_subdirectory(const_fold) +add_subdirectory(node) +add_subdirectory(graph) \ No newline at end of file diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/CMakeLists.txt b/mindspore/lite/tools/converter/optimizer/const_fold/CMakeLists.txt new file mode 100644 index 00000000000..fdd03aca274 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/CMakeLists.txt @@ -0,0 +1,50 @@ +set(OP_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/tensor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/context.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/runtime/allocator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_common.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_factory.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/op_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/common/op_func_comm.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/common/op_nc4hw4_comm.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/add.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/cast.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/concat.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/fp32/add_fp32.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/fp32/concat_fp32.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/add_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/concat_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/expand_dim.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/mul.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/range.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/reshape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/uint8/reshape_uint8.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/rsqrt.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/shape.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/slice.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/stack.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/strided_slice.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/sub.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/tile.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/operator/cpu/creator/transpose.cc + ) + +add_library(const_fold_mid OBJECT + ${OP_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/add_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cast_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/concat_v2_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/expand_dims_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/mul_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/range_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/reshape_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rsqrt_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/shape_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/slice_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/stack_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/strided_slice_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/sub_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tile_const_fold_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/transpose_const_fold_pass.cc) diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.cc new file mode 100644 index 00000000000..f74ba7e2689 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/add_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/add.h" + +namespace mindspore { +namespace lite { + +STATUS AddConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS AddConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Add; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpAdd(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpAdd return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS AddConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for add must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.h new file mode 100644 index 00000000000..28f758a7558 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/add_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class AddConstFoldPass : public ConstFoldPass { + public: + AddConstFoldPass() : ConstFoldPass(OpT_Add) {} + + ~AddConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_ADD_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.cc new file mode 100644 index 00000000000..821018b7d32 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/cast_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/cast.h" + +#define CAST_OUTPUT_NUM 1 + +namespace mindspore { +namespace lite { +STATUS CastConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS CastConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Cast; + desc.arch = kCPU; + op = new (std::nothrow) OpCast(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpCast return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS CastConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpCast Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != CAST_OUTPUT_NUM) { + MS_LOGE("The number of output for cast must be %u, nodeName: %s", CAST_OUTPUT_NUM, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.h new file mode 100644 index 00000000000..65c07dca058 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/cast_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class CastConstFoldPass : public ConstFoldPass { + public: + CastConstFoldPass() : ConstFoldPass(OpT_Cast) {} + + ~CastConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CAST_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.cc new file mode 100644 index 00000000000..360181fb8a1 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h" +#include "src/operator/cpu/creator/concat.h" + +namespace mindspore { +namespace lite { + +STATUS ConcatV2ConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ConcatV2ConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Concat; + desc.arch = kCPU; + op = new (std::nothrow) OpConcat(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpConcat return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ConcatV2ConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpConcat Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kConcatOutputNum) { + MS_LOGE("The number of output for concat must be %u, nodeName: %s", kConcatOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h new file mode 100644 index 00000000000..5833892e809 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/concat_v2_const_fold_pass.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +class ConcatV2ConstFoldPass : public ConstFoldPass { + public: + ConcatV2ConstFoldPass() : ConstFoldPass(OpT_Concat) {} + + ~ConcatV2ConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; + + private: + template + STATUS DoConcat(SubGraphDefT *subGraph, const std::vector &inTensorIdxes, int axis) { + MS_ASSERT(this->outputTensor != nullptr); + std::vector inTensors; + std::vector inDatas; + for (size_t i = 0; i < inTensorIdxes.size(); i++) { + auto &inTensor = subGraph->allTensors.at(inTensorIdxes.at(i)); + MS_ASSERT(inTensor != nullptr); + inTensors.emplace_back(inTensor.get()); + void *inData = inTensor->data.data(); + MS_ASSERT(inData != nullptr); + T *castedInData = static_cast(inData); + MS_ASSERT(castedInData != nullptr); + inDatas.emplace_back(castedInData); + } + auto &inShape = subGraph->allTensors.at(inTensorIdxes.at(0))->dims; + std::vector outputDims; + for (size_t i = 0; i < inShape.size(); i++) { + if (i == axis) { + int32_t axisDim = 0; + for (size_t j = 0; j < inTensors.size(); j++) { + axisDim += inTensors.at(j)->dims.at(i); + } + outputDims.push_back(axisDim); + continue; + } + outputDims.push_back(inShape.at(i)); + } + + size_t outShapeSize = 1; + for (auto dim : outputDims) { + outShapeSize *= dim; + } + size_t elementSize = GetElementSize(subGraph->allTensors.at(inTensorIdxes.at(0))->dataType); + + this->outputTensor->dims = outputDims; + this->outputTensor->data.clear(); + this->outputTensor->data.resize(outShapeSize * elementSize); + + void *outData = this->outputTensor->data.data(); + MS_ASSERT(outData != nullptr); + T *castedOutData = static_cast(outData); + + size_t copyBlockTile = 1; + for (int i = axis + 1; i < inShape.size(); i++) { + copyBlockTile *= inShape[i]; + } + std::vector inCopyBlocks; + size_t outCopyBlock = 0; + for (size_t i = 0; i < inTensors.size(); i++) { + inCopyBlocks.emplace_back(copyBlockTile * (inTensors.at(i)->dims.at(axis))); + outCopyBlock += inCopyBlocks.back(); + } + + size_t outIndex = 0; + while (outIndex < outShapeSize) { + for (size_t i = 0; i < inDatas.size(); i++) { + ::memcpy_s(castedOutData + outIndex, inCopyBlocks.at(i), inDatas.at(i), inCopyBlocks.at(i)); + outIndex += inCopyBlocks.at(i); + inDatas.at(i) += inCopyBlocks.at(i); + } + } + + return RET_OK; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONCAT_V2_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.cc new file mode 100644 index 00000000000..30905c33dad --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/const_fold_pass.h" +#include +#include "utils/log_adapter.h" +#include "converter/common/graph_util.h" + +namespace mindspore { +namespace lite { +STATUS ConstFoldPass::Run(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto subGraph = graphNode->subGraph; + auto node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + if (GetOpType(*node) != opType) { + return RET_OK; + } + if (!IsFoldable(subGraph, node)) { + MS_LOGD("All input should be ConstTensor, node : %s"); + return RET_OK; + } + + for (uint32_t i : node->inputIndex) { + TensorDefT *tensorDefT = subGraph->allTensors.at(i).get(); + MS_ASSERT(tensorDefT != nullptr); + auto tensor = CopyTensorDefT2Tensor(tensorDefT); + if (tensor == nullptr) { + MS_LOGE("Pack TensorDefT return nullptr"); + FreeTensors(); + return RET_ERROR; + } + inputs.emplace_back(tensor); + } + for (uint32_t i : node->outputIndex) { + TensorDefT *tensorDefT = subGraph->allTensors.at(i).get(); + MS_ASSERT(tensorDefT != nullptr); + auto tensor = CopyTensorDefT2Tensor(tensorDefT, false); + if (tensor == nullptr) { + MS_LOGE("Pack TensorDefT return nullptr"); + FreeTensors(); + return RET_ERROR; + } + outputs.emplace_back(tensor); + } + + auto status = CreateOp(subGraph, node); + if (status != RET_OK) { + MS_LOGE("CreateOp error: %d, node: %s", status, node->name.c_str()); + FreeTensors(); + return status; + } + for (auto &outputTensor : outputs) { + auto statusTmp = outputTensor->MallocData(); + if (statusTmp != RET_OK) { + MS_LOGE("OutTensor MallocData error: %d, nodeName: %s", statusTmp, node->name.c_str()); + FreeTensors(); + return RET_ERROR; + } + } + status = DoFold(subGraph, node); + if (status != RET_OK) { + MS_LOGE("DoFold error: %d, node: %s", status, node->name.c_str()); + FreeTensors(); + return status; + } + + if (this->outputTensor->data.empty()) { + MS_LOGI("outputTensor's data has not been set, node : %s", node->name.c_str()); + FreeTensors(); + return RET_OK; + } + this->outputTensor->refCount = schema::NodeType_ValueNode; + bool isSubNode = false; + for (auto &inNode : subGraph->nodes) { + if (inNode->name == node->name) { + isSubNode = true; + break; + } + } + if (!isSubNode) { + MS_LOGE("Node %s is not in subGraph %s", node->name.c_str(), subGraph->name.c_str()); + return RET_PARAM_INVALID; + } else { + status = RemoveTensor(subGraph, node->inputIndex); + if (status != RET_OK) { + MS_LOGE("RemoveTensor failed, node : %s", node->name.c_str()); + FreeTensors(); + return status; + } + // we can not erase nodes in iter loop, so just isolate the node + node->inputIndex.clear(); + node->outputIndex.clear(); + } + + FreeTensors(); + return RET_OK; +} + +OpDef *ConstFoldPass::PackOpDefT(const OpDefT *opDefT) { + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = OpDef::Pack(builder, opDefT); + builder.Finish(offset); + auto buf = builder.GetBufferPointer(); + auto opDef = flatbuffers::GetRoot(buf); + return const_cast(opDef); +} + +Tensor *ConstFoldPass::CopyTensorDefT2Tensor(const TensorDefT *tensorDefT, bool needCopyData) { + if (tensorDefT == nullptr) { + MS_LOGE("tensorDefT is null"); + return nullptr; + } + std::vector dims; + for (size_t i = 0; i < tensorDefT->dims.size(); i++) { + dims.emplace_back(tensorDefT->dims.at(i)); + } + + auto tensor = new (std::nothrow) Tensor(tensorDefT->dataType, dims, tensorDefT->format, nullptr); + if (tensor == nullptr) { + MS_LOGE("new tensor error"); + return nullptr; + } + if (needCopyData) { + auto status = tensor->MallocData(); + if (status != RET_OK) { + MS_LOGE("malloc tensor data error: %d", status); + delete (tensor); + return nullptr; + } + size_t dataLength = tensor->GetDataSize(); + status = ::memcpy_s(tensor->GetData(), dataLength, tensorDefT->data.data(), dataLength); + if (status != 0) { + MS_LOGE("memcpy_s error: %d", status); + delete (tensor); + return nullptr; + } + } + return tensor; +} + +STATUS ConstFoldPass::CopyTensor2TensorDefT(const Tensor *tensor, TensorDefT *tensorDefT) { + MS_ASSERT(tensorDefT != nullptr); + if (tensor == nullptr) { + MS_LOGE("tensor is null"); + return RET_ERROR; + } + + tensorDefT->dims.clear(); + for (size_t i = 0; i < tensor->GetNDim(); i++) { + tensorDefT->dims.emplace_back(tensor->GetDims().at(i)); + } + tensorDefT->dataType = tensor->GetDataType(); + tensorDefT->format = tensor->GetFormat(); + size_t dataLength = tensor->GetDataSize(); + tensorDefT->data.resize(dataLength); + auto ret = ::memcpy_s(tensorDefT->data.data(), dataLength, tensor->GetData(), dataLength); + if (ret != 0) { + MS_LOGE("memcpy_s error: %d", ret); + return RET_ERROR; + } + return RET_OK; +} + +bool ConstFoldPass::IsFoldable(SubGraphDefT *subGraph, OpDefT *node) { + bool isFoldable = true; + for (auto tensorIdx : node->inputIndex) { + auto &tensor = subGraph->allTensors.at(tensorIdx); + if (tensor->refCount != schema::NodeType_ValueNode || tensor->data.empty()) { + isFoldable = false; + break; + } + } + return isFoldable; +} + +void ConstFoldPass::FreeTensors() { + for (auto tensor : inputs) { + if (tensor != nullptr) { + delete (tensor); + } + } + inputs.clear(); + for (auto tensor : outputs) { + if (tensor != nullptr) { + delete (tensor); + } + } + outputs.clear(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.h new file mode 100644 index 00000000000..cbc0103598b --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/const_fold_pass.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_CONST_FOLD_PASS_H + +#include +#include "mindspore/lite/tools/converter/optimizer.h" +#include "include/tensor.h" +#include "utils/log_adapter.h" +#include "converter/common/converter_op_utils.h" +#include "securec/include/securec.h" +#include "src/op.h" + +namespace mindspore { +namespace lite { +class ConstFoldPass : public NodePass { + public: + explicit ConstFoldPass(schema::PrimitiveType opType) : opType(opType) {} + + ~ConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + protected: + bool IsFoldable(SubGraphDefT *subGraph, OpDefT *node); + + virtual STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) = 0; + + virtual STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) = 0; + + protected: + OpDef *PackOpDefT(const OpDefT *opDefT); + + Tensor *CopyTensorDefT2Tensor(const TensorDefT *tensorDefT, bool needCopyData = true); + + STATUS CopyTensor2TensorDefT(const Tensor *tensor, TensorDefT *tensorDefT); + + void FreeTensors(); + + protected: + schema::PrimitiveType opType; + TensorDefT *outputTensor = nullptr; + std::vector inputs; + std::vector outputs; + OpBase *op = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.cc new file mode 100644 index 00000000000..9bf262d0f1a --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/expand_dim.h" + +namespace mindspore { +namespace lite { +STATUS ExpandDimsConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ExpandDimsConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_ExpandDims; + desc.arch = kCPU; + op = new (std::nothrow) OpExpandDim(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpExpandDim return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ExpandDimsConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpExpandDim Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kExpandDimsOutputNum) { + MS_LOGE("The number of output for expandDim must be %u, nodeName: %s", kExpandDimsOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h new file mode 100644 index 00000000000..12e9d979a10 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/expand_dims_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ExpandDimsConstFoldPass : public ConstFoldPass { + public: + ExpandDimsConstFoldPass() : ConstFoldPass(OpT_ExpandDims) {} + + ~ExpandDimsConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_EXPANDDIMS_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.cc new file mode 100644 index 00000000000..e75308ec6dc --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/mul_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "converter/common/tensor_util.h" +#include "converter/common/converter_op_utils.h" +#include "src/operator/cpu/creator/mul.h" + +namespace mindspore { +namespace lite { +STATUS MulConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS MulConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Mul; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + op = nullptr; + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpMul(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpMul return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS MulConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpMul Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for mul must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.h new file mode 100644 index 00000000000..6c47b60c639 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/mul_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class MulConstFoldPass : public ConstFoldPass { + public: + MulConstFoldPass() : ConstFoldPass(OpT_Mul) {} + + ~MulConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MUL_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.cc new file mode 100644 index 00000000000..6552a7af48e --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/range_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/range.h" + +namespace mindspore { +namespace lite { +#define kRangeOutputNum 1 + +STATUS RangeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS RangeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Range; + desc.arch = kCPU; + op = new (std::nothrow) OpRange(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpAdd return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS RangeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpAdd Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kRangeOutputNum) { + MS_LOGE("The number of range for range must be %u, nodeName: %s", kRangeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.h new file mode 100644 index 00000000000..e8b48e40042 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/range_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class RangeConstFoldPass : public ConstFoldPass { + public: + RangeConstFoldPass() : ConstFoldPass(OpT_Range) {} + + ~RangeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RANGE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.cc new file mode 100644 index 00000000000..2a42fd0633f --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/reshape_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/reshape.h" + +namespace mindspore { +namespace lite { +STATUS ReshapeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ReshapeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Reshape; + desc.arch = kCPU; + op = new (std::nothrow) OpReshape(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpReshape return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ReshapeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpReshape Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kReshapeOutputNum) { + MS_LOGE("The number of output for Reshape must be %u, nodeName: %s", kReshapeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.h new file mode 100644 index 00000000000..13a79503956 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/reshape_const_fold_pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ReshapeConstFoldPass : public ConstFoldPass { + public: + ReshapeConstFoldPass() : ConstFoldPass(OpT_Reshape) {} + + ~ReshapeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; + + private: + STATUS CalNewShape(const TensorDefT &inTensor, std::vector &outShape); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RESHAPE_CONST_FOLD_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.cc new file mode 100644 index 00000000000..1f912eccd2a --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/fp32/rsqrt_fp32.h" + +namespace mindspore { +namespace lite { + +STATUS RsqrtConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS RsqrtConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Rsqrt; + desc.arch = kCPU; + op = new (std::nothrow) RsqrtFp32(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpRsqrt return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS RsqrtConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpRsqrt Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kRsqrtOutputNum) { + MS_LOGE("The number of output for Rsqrt must be %u, nodeName: %s", kRsqrtOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h new file mode 100644 index 00000000000..7ce1fc16112 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/rsqrt_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H + +#include +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class RsqrtConstFoldPass : public ConstFoldPass { + public: + RsqrtConstFoldPass() : ConstFoldPass(OpT_Rsqrt) {} + + ~RsqrtConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_RSQRT_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.cc new file mode 100644 index 00000000000..d3d18b8715f --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/shape_const_fold_pass.h" +#include "src/operator/cpu/creator/shape.h" + +namespace mindspore { +namespace lite { +STATUS ShapeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS ShapeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Shape; + desc.arch = kCPU; + op = new (std::nothrow) OpShape(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpShape return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS ShapeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpShape Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kShapeOutputNum) { + MS_LOGE("The number of output for shape must be %u, nodeName: %s", kShapeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.h new file mode 100644 index 00000000000..7f05a9b9e2c --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/shape_const_fold_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +class ShapeConstFoldPass : public ConstFoldPass { + public: + ShapeConstFoldPass() : ConstFoldPass(OpT_Shape) {} + + ~ShapeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SHAPE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.cc new file mode 100644 index 00000000000..88aa527635f --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/slice_const_fold_pass.h" +#include "src/operator/cpu/creator/slice.h" + +namespace mindspore { +namespace lite { +// todo if slice op has placeholder tensor +STATUS SliceConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS SliceConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Slice; + desc.arch = kCPU; + op = new (std::nothrow) OpSlice(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpSlice return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS SliceConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSlice Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kSliceOutputNum) { + MS_LOGE("The number of output for slice must be %u, nodeName: %s", kSliceOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.h new file mode 100644 index 00000000000..c5d7ca3470e --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/slice_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +// This Op only supports 1-4D cases +class SliceConstFoldPass : public ConstFoldPass { + public: + SliceConstFoldPass() : ConstFoldPass(OpT_Slice) {} + + ~SliceConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SLICE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.cc new file mode 100644 index 00000000000..b562e2506d9 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/stack_const_fold_pass.h" +#include "src/operator/cpu/creator/stack.h" + +namespace mindspore { +namespace lite { +STATUS StackConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS StackConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Stack; + desc.arch = kCPU; + op = new (std::nothrow) OpStack(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpStack return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS StackConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStack Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kStackOutputNum) { + MS_LOGE("The number of output for stack must be %u, nodeName: %s", kStackOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.h new file mode 100644 index 00000000000..2f366696161 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/stack_const_fold_pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +class StackConstFoldPass : public ConstFoldPass { + public: + StackConstFoldPass() : ConstFoldPass(OpT_Stack) {} + + ~StackConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_STACK_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.cc new file mode 100644 index 00000000000..abc84ffcafd --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h" +#include "src/operator/cpu/creator/strided_slice.h" + +namespace mindspore { +namespace lite { +STATUS StridedSliceConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS StridedSliceConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Slice; + desc.arch = kCPU; + op = new (std::nothrow) OpStridedSlice(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpStridedSlice return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS StridedSliceConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpStridedSlice Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kStridedSliceOutputNum) { + MS_LOGE("The number of output for slice must be %u, nodeName: %s", kStridedSliceOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h new file mode 100644 index 00000000000..bbb33871411 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/strided_slice_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" + +namespace mindspore { +namespace lite { +// This Op only supports 1-4D cases +class StridedSliceConstFoldPass : public ConstFoldPass { + public: + StridedSliceConstFoldPass() : ConstFoldPass(OpT_StridedSlice) {} + + ~StridedSliceConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_STRIDED_SLICE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.cc new file mode 100644 index 00000000000..fd4c8e4305c --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/sub_const_fold_pass.h" + +#include "utils/log_adapter.h" +#include "converter/common/tensor_util.h" +#include "converter/common/converter_op_utils.h" +#include "src/operator/cpu/creator/sub.h" + +namespace mindspore { +namespace lite { + +STATUS SubConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS SubConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Sub; + desc.arch = kCPU; + MS_ASSERT(inputs.size() == kArithOpInputNum); + auto inTensor0 = inputs.at(kArithOpInputTensorIndex0); + auto inTensor1 = inputs.at(kArithOpInputTensorIndex1); + MS_ASSERT(inTensor0 != nullptr); + MS_ASSERT(inTensor1 != nullptr); + DataType dataType; + if (inTensor0->GetNDim() > 1) { + dataType = inTensor0->GetDataType(); + } else { + dataType = inTensor1->GetDataType(); + } + switch (dataType) { + case DataType_DT_UINT8: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT32: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_FLOAT: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_INT8: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + case DataType_DT_UINT32: { + op = new (std::nothrow) OpSub(inputs, outputs, *PackOpDefT(node), &ctx, desc); + } break; + default: { + MS_LOGE("Unsupported dataType: %d", dataType); + return RET_ERROR; + } + } + if (op == nullptr) { + MS_LOGE("new OpSub return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS SubConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpSub Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kArithOpOutputNum) { + MS_LOGE("The number of output for sub must be %u, nodeName: %s", kArithOpOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.h new file mode 100644 index 00000000000..2feecb29541 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/sub_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class SubConstFoldPass : public ConstFoldPass { + public: + SubConstFoldPass() : ConstFoldPass(OpT_Sub) {} + + ~SubConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_SUB_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.cc new file mode 100644 index 00000000000..0f122fa08c8 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/tile_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/tile.h" + +namespace mindspore { +namespace lite { +STATUS TileConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS TileConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Tile; + desc.arch = kCPU; + op = new (std::nothrow) OpTile(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpTile return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS TileConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTile Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kTileOutputNum) { + MS_LOGE("The number of output for tile must be %u, nodeName: %s", kTileOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.h new file mode 100644 index 00000000000..df7404b8cde --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/tile_const_fold_pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" + +namespace mindspore { +namespace lite { +class TileConstFoldPass : public ConstFoldPass { + public: + TileConstFoldPass() : ConstFoldPass(OpT_Tile) {} + + ~TileConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TILE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.cc b/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.cc new file mode 100644 index 00000000000..5a171efa2b2 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/const_fold/transpose_const_fold_pass.h" +#include "utils/log_adapter.h" +#include "src/operator/cpu/creator/transpose.h" + +namespace mindspore { +namespace lite { + +STATUS TransposeConstFoldPass::Run(GraphNode *graphNode) { return ConstFoldPass::Run(graphNode); } + +STATUS TransposeConstFoldPass::CreateOp(SubGraphDefT *subGraph, OpDefT *node) { + InnerContext ctx; + OpDesc desc{}; + desc.type = OpT_Transpose; + desc.arch = kCPU; + op = new (std::nothrow) OpTranspose(inputs, outputs, *PackOpDefT(node), &ctx, desc); + if (op == nullptr) { + MS_LOGE("new OpTranspose return nullptr"); + return RET_ERROR; + } + auto ret = op->InferShape(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose InferShape Failed"); + return RET_ERROR; + } + ret = op->Init(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose Init Failed"); + return RET_ERROR; + } + return RET_OK; +} + +STATUS TransposeConstFoldPass::DoFold(SubGraphDefT *subGraph, OpDefT *node) { + MS_ASSERT(op != nullptr); + auto ret = op->Execute(inputs, outputs); + if (ret != RET_OK) { + MS_LOGE("OpTranspose Execute Failed"); + return RET_ERROR; + } + + if (node->outputIndex.size() != kTransposeOutputNum) { + MS_LOGE("The number of output for transpose must be %u, nodeName: %s", kTransposeOutputNum, node->name.c_str()); + return RET_ERROR; + } + this->outputTensor = subGraph->allTensors.at(node->outputIndex.front()).get(); + CopyTensor2TensorDefT(outputs.front(), this->outputTensor); + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.h b/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.h new file mode 100644 index 00000000000..902b564c899 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/const_fold/transpose_const_fold_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H +#define MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H + +#include "converter/optimizer/const_fold/const_fold_pass.h" +#include "converter/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class TransposeConstFoldPass : public ConstFoldPass { + public: + TransposeConstFoldPass() : ConstFoldPass(OpT_Transpose) {} + + ~TransposeConstFoldPass() override = default; + + STATUS Run(GraphNode *graphNode) override; + + STATUS CreateOp(SubGraphDefT *subGraph, OpDefT *node) override; + + STATUS DoFold(SubGraphDefT *subGraph, OpDefT *node) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_TRANSPOSE_CONST_FOLD_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/optimizer/fusion/CMakeLists.txt new file mode 100755 index 00000000000..32aa9d4dac1 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(fusion_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_bias_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_bn_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_activation_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_relu6_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/conv_biasadd_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc + ) + +target_link_libraries(fusion_mid securec) diff --git a/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.cc new file mode 100644 index 00000000000..a932834c1b8 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.cc @@ -0,0 +1,500 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h" +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define kBatchNormFoldFusionPathLen6 6 +#define kBatchNormFoldFusionPathLen7 7 + +STATUS BatchNormFoldFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS BatchNormFoldFusionPass::DefinePattern() { + // with preNode + { + auto inputOp = std::make_shared(); + inputOp->id = inputOpName; + inputOp->types = {schema::PrimitiveType_NONE}; + inputOp->isPlaceHold = true; + + auto convOp1 = std::make_shared(); + convOp1->id = convPatternOpName1; + convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp1->left = inputOp; + + auto bnFoldOp = std::make_shared(); + bnFoldOp->id = bnFoldOpName; + bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; + bnFoldOp->left = convOp1; + + auto mulFoldOp = std::make_shared(); + mulFoldOp->id = mulFoldOpName; + mulFoldOp->types = {schema::PrimitiveType_MulFold}; + mulFoldOp->left = bnFoldOp; + + auto fakeQuantOp = std::make_shared(); + fakeQuantOp->id = fakeQuantOpName; + fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; + fakeQuantOp->left = mulFoldOp; + + auto convOp2 = std::make_shared(); + convOp2->id = convPatternOpName2; + convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp2->left = fakeQuantOp; + convOp2->right = inputOp; + + auto addFoldOp = std::make_shared(); + addFoldOp->id = addFoldOpName; + addFoldOp->types = {schema::PrimitiveType_AddFold}; + addFoldOp->left = convOp2; + addFoldOp->right = bnFoldOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(withPrePatternName)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(inputOp); + fusionPattern->AddPatternOp(convOp1); + fusionPattern->AddPatternOp(bnFoldOp); + fusionPattern->AddPatternOp(mulFoldOp); + fusionPattern->AddPatternOp(fakeQuantOp); + fusionPattern->AddPatternOp(convOp2); + fusionPattern->AddPatternOp(addFoldOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + // no preNode + { + auto convOp1 = std::make_shared(); + convOp1->id = convPatternOpName1; + convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + + auto bnFoldOp = std::make_shared(); + bnFoldOp->id = bnFoldOpName; + bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; + bnFoldOp->left = convOp1; + + auto mulFoldOp = std::make_shared(); + mulFoldOp->id = mulFoldOpName; + mulFoldOp->types = {schema::PrimitiveType_MulFold}; + mulFoldOp->left = bnFoldOp; + + auto fakeQuantOp = std::make_shared(); + fakeQuantOp->id = fakeQuantOpName; + fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; + fakeQuantOp->left = mulFoldOp; + + auto convOp2 = std::make_shared(); + convOp2->id = convPatternOpName2; + convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + convOp2->left = fakeQuantOp; + + auto addFoldOp = std::make_shared(); + addFoldOp->id = addFoldOpName; + addFoldOp->types = {schema::PrimitiveType_AddFold}; + addFoldOp->left = convOp2; + addFoldOp->right = bnFoldOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(noPrePatternName)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp1); + fusionPattern->AddPatternOp(bnFoldOp); + fusionPattern->AddPatternOp(mulFoldOp); + fusionPattern->AddPatternOp(fakeQuantOp); + fusionPattern->AddPatternOp(convOp2); + fusionPattern->AddPatternOp(addFoldOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (patternName == withPrePatternName) { + if (matchedPath.size() != kBatchNormFoldFusionPathLen7) { + MS_LOG(ERROR) << "BatchNormFold-Fusion should have seven NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + } else if (patternName == noPrePatternName) { + if (matchedPath.size() != kBatchNormFoldFusionPathLen6) { + MS_LOG(ERROR) << "BatchNormFold-Fusion should have six NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + } + + auto status = FindNodes(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "FindNodes failed: " << status; + return status; + } + status = CheckPath(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "CheckPath failed: " << status; + return status; + } + status = FindTensors(); + if (status != RET_OK) { + MS_LOG(ERROR) << "FindTensors failed: " << status; + return status; + } + status = GenNewWeightTensor(); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenNewWeightTensor failed: " << status; + return status; + } + status = GenNewBiasTensor(); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenNewBiasTensor failed: " << status; + return status; + } + status = IsolateNodes(graph, matchedPath); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateNodes failed: " << status; + return status; + } + UpdateConvWeights(); + status = DeleteConstTensors(); + if (status != RET_OK) { + MS_LOG(ERROR) << "DeleteConstTensors failed: " << status; + return status; + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::FindNodes(MetaGraphT *graph, + const std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + auto preConvPath = matchedPath.at(convPatternOpName1); + auto bnFoldPath = matchedPath.at(bnFoldOpName); + auto mulFoldPath = matchedPath.at(mulFoldOpName); + auto fakeQuantPath = matchedPath.at(fakeQuantOpName); + auto convPath = matchedPath.at(convPatternOpName2); + auto addFoldPath = matchedPath.at(addFoldOpName); + MS_ASSERT(preConvPath != nullptr); + MS_ASSERT(bnFoldPath != nullptr); + MS_ASSERT(mulFoldPath != nullptr); + MS_ASSERT(fakeQuantPath != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(addFoldPath != nullptr); + if (preConvPath->subGraphIdx != bnFoldPath->subGraphIdx || preConvPath->subGraphIdx != mulFoldPath->subGraphIdx || + preConvPath->subGraphIdx != fakeQuantPath->subGraphIdx || preConvPath->subGraphIdx != convPath->subGraphIdx || + preConvPath->subGraphIdx != addFoldPath->subGraphIdx) { + MS_LOG(ERROR) << "matched nodes should from same subGraph"; + return RET_ERROR; + } + MS_ASSERT(graph->nodes.size() > preConvPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > bnFoldPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > mulFoldPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > fakeQuantPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > convPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > addFoldPath->nodeIdx); + preConv = graph->nodes.at(preConvPath->nodeIdx).get(); + bnFold = graph->nodes.at(bnFoldPath->nodeIdx).get(); + mulFold = graph->nodes.at(mulFoldPath->nodeIdx).get(); + fakeNode = graph->nodes.at(fakeQuantPath->nodeIdx).get(); + convNode = graph->nodes.at(convPath->nodeIdx).get(); + addFold = graph->nodes.at(addFoldPath->nodeIdx).get(); + MS_ASSERT(preConv != nullptr); + MS_ASSERT(bnFold != nullptr); + MS_ASSERT(mulFold != nullptr); + MS_ASSERT(fakeNode != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(addFold != nullptr); + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::FindTensors() { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnFold != nullptr); + MS_ASSERT(addFold != nullptr); + if (bnFold->inputIndex.size() != 4) { + MS_LOG(ERROR) << "BatchNormFold node should have 4 inputTensor, got " << bnFold->inputIndex.size() + << " input tensors"; + return RET_ERROR; + } + if (addFold->inputIndex.size() != 5) { + MS_LOG(ERROR) << "AddFold node should have 5 inputTensor, got " << addFold->inputIndex.size() << " input tensors"; + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(1)); + muTensor = graph->allTensors.at(bnFold->inputIndex.at(1)).get(); + MS_ASSERT(muTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(2)); + sigmaTensor = graph->allTensors.at(bnFold->inputIndex.at(2)).get(); + MS_ASSERT(sigmaTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(1)); + betaTensor = graph->allTensors.at(addFold->inputIndex.at(1)).get(); + MS_ASSERT(betaTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(2)); + gammaTensor = graph->allTensors.at(addFold->inputIndex.at(2)).get(); + MS_ASSERT(gammaTensor != nullptr); + + if (betaTensor->dims.size() != 1) { + MS_LOG(ERROR) << "ConstTensor should have only one dim, got " << betaTensor->dims.size(); + return RET_ERROR; + } + if (betaTensor->dims != gammaTensor->dims || betaTensor->dims != sigmaTensor->dims || + betaTensor->dims != muTensor->dims) { + MS_LOG(ERROR) << "All ConstTensor should have same dims"; + return RET_ERROR; + } + channelOut = betaTensor->dims.front(); + + MS_ASSERT(mulFold != nullptr); + if (mulFold->inputIndex.size() != 3) { + MS_LOG(ERROR) << "MulFold node should have 3 outputTensor, got " << addFold->inputIndex.size() << " output tensors"; + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > mulFold->inputIndex.front()); + oldWeightTensor = graph->allTensors.at(mulFold->inputIndex.front()).get(); + MS_ASSERT(oldWeightTensor != nullptr); + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::CheckPath(MetaGraphT *graph, + const std::unordered_map> &matchedPath) { + MS_ASSERT(preConv != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(mulFold != nullptr); + MS_ASSERT(preConv->inputIndex.size() == 2); + MS_ASSERT(convNode->inputIndex.size() == 2); + MS_ASSERT(mulFold->inputIndex.size() == 3); + MS_ASSERT(preConv->inputIndex.front() == convNode->inputIndex.front()); + MS_ASSERT(preConv->inputIndex.at(1) == mulFold->inputIndex.front()); + // todo + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::GenNewWeightTensor() { + MS_ASSERT(oldWeightTensor != nullptr); + MS_ASSERT(oldWeightTensor->dataType == DataType_DT_FLOAT); + MS_ASSERT(oldWeightTensor->refCount == schema::NodeType_ValueNode); + auto weightShape = oldWeightTensor->dims; + if (weightShape.size() != 4) { + MS_LOG(ERROR) << "shape of weight should be 4 dims, got " << weightShape.size() << " dims"; + return RET_ERROR; + } + if (weightShape.front() != channelOut) { + MS_LOG(ERROR) << "weight should be in KCHW format, and outputChannel should be " << channelOut; + return RET_ERROR; + } + auto weightShapeSize = GetShapeSize(*oldWeightTensor); + newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newWeightTensor == nullptr) { + MS_LOG(ERROR) << "new weightTensor failed"; + return RET_ERROR; + } + newWeightTensor->dataType = oldWeightTensor->dataType; + newWeightTensor->format = oldWeightTensor->format; + newWeightTensor->refCount = schema::NodeType_ValueNode; + newWeightTensor->dims = weightShape; + newWeightTensor->data.resize(weightShapeSize * sizeof(float)); + void *oldWeightData = oldWeightTensor->data.data(); + auto castedOldWeightData = static_cast(oldWeightData); + void *newWeightData = newWeightTensor->data.data(); + auto castedNewWeightData = static_cast(newWeightData); + MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); + void *gammaData = gammaTensor->data.data(); + auto *castedGammaData = static_cast(gammaData); + MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); + void *miData = muTensor->data.data(); + auto *castedMiData = static_cast(miData); + size_t stride = weightShapeSize / channelOut; + for (size_t i = 0; i < channelOut; i++) { + for (size_t j = 0; j < stride; j++) { + castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i]; + } + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::GenNewBiasTensor() { // bias has no quant + std::vector biasShape = {channelOut}; + newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (newBiasTensor == nullptr) { + MS_LOG(ERROR) << "new BiasTensor failed"; + return RET_ERROR; + } + newBiasTensor->dataType = 0; // todo is float + newBiasTensor->format = Format_NUM_OF_FORMAT; + newBiasTensor->refCount = schema::NodeType_ValueNode; + newBiasTensor->dims = biasShape; + newBiasTensor->data.resize(channelOut * sizeof(float)); + void *newBiasData = newBiasTensor->data.data(); + auto castedNewBiasData = static_cast(newBiasData); + MS_ASSERT(betaTensor->dataType == DataType_DT_FLOAT); + void *betaData = betaTensor->data.data(); + auto *castedBetaData = static_cast(betaData); + MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); + void *gammaData = gammaTensor->data.data(); + auto *castedGammaData = static_cast(gammaData); + MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); + void *miData = muTensor->data.data(); + auto *castedMiData = static_cast(miData); + MS_ASSERT(sigmaTensor->dataType == DataType_DT_FLOAT); + void *sigmaData = sigmaTensor->data.data(); + auto *castedSigmaData = static_cast(sigmaData); + for (size_t i = 0; i < channelOut; i++) { + castedNewBiasData[i] = castedBetaData[i] - castedGammaData[i] * castedMiData[i] / castedSigmaData[i]; + } + return RET_OK; +} + +STATUS BatchNormFoldFusionPass::IsolateNodes( + MetaGraphT *graph, const std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + auto preConvPath = matchedPath.at(convPatternOpName1); + auto bnFoldPath = matchedPath.at(bnFoldOpName); + auto mulFoldPath = matchedPath.at(mulFoldOpName); + auto fakeQuantPath = matchedPath.at(fakeQuantOpName); + auto convPath = matchedPath.at(convPatternOpName2); + auto addFoldPath = matchedPath.at(addFoldOpName); + MS_ASSERT(preConvPath != nullptr); + MS_ASSERT(bnFoldPath != nullptr); + MS_ASSERT(mulFoldPath != nullptr); + MS_ASSERT(fakeQuantPath != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(addFoldPath != nullptr); + auto status = IsolateOneWayNode(graph, preConvPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << preConv->name.c_str() << " failed, error: " << status; + return status; + } + std::vector toDeleteTensorIdxes; + toDeleteTensorIdxes.emplace_back(bnFold->inputIndex.at(3)); + toDeleteTensorIdxes.insert(toDeleteTensorIdxes.end(), bnFold->outputIndex.begin(), bnFold->outputIndex.end()); + status = RemoveTensor(graph, toDeleteTensorIdxes, true); + if (status != RET_OK) { + MS_LOG(ERROR) << "Remove Tensors of BnFold " << bnFold->name.c_str() << " failed, error: " << status; + return RET_ERROR; + } + status = IsolateOneWayNode(graph, bnFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << bnFold->name.c_str() << " failed, error: " << status; + return status; + } + status = IsolateOneWayNode(graph, mulFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << mulFold->name.c_str() << " failed, error: " << status; + return status; + } + status = IsolateOneWayNode(graph, addFoldPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode " << addFold->name.c_str() << " failed, error: " << status; + return status; + } + return RET_OK; +} + +void BatchNormFoldFusionPass::UpdateConvWeights() { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convNode != nullptr); + MS_ASSERT(newWeightTensor != nullptr); + MS_ASSERT(newBiasTensor != nullptr); + MS_ASSERT(graph->allTensors.size() > fakeNode->inputIndex.at(0)); + graph->allTensors.at(fakeNode->inputIndex.at(0)).reset(); + graph->allTensors.at(fakeNode->inputIndex.at(0)) = std::move(this->newWeightTensor); + graph->allTensors.emplace_back(std::move(this->newBiasTensor)); + convNode->inputIndex.emplace_back(graph->allTensors.size() - 1); + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else { + MS_ASSERT(false); + } + + this->oldWeightTensor = nullptr; + this->newWeightTensor = nullptr; + this->newBiasTensor = nullptr; +} + +STATUS BatchNormFoldFusionPass::DeleteConstTensors() { + MS_ASSERT(graph != nullptr); + bool muFind = false; + bool sigmaFind = false; + bool betaFind = false; + bool gammaFind = false; + std::vector toDeleteTensorIdxes; + for (size_t i = 0; i < graph->allTensors.size(); i++) { + auto &tensor = graph->allTensors.at(i); + if (tensor.get() == muTensor) { + toDeleteTensorIdxes.emplace_back(i); + muFind = true; + this->muTensor = nullptr; + } + if (tensor.get() == sigmaTensor) { + toDeleteTensorIdxes.emplace_back(i); + sigmaFind = true; + this->sigmaTensor = nullptr; + } + if (tensor.get() == gammaTensor) { + toDeleteTensorIdxes.emplace_back(i); + gammaFind = true; + this->gammaTensor = nullptr; + } + if (tensor.get() == betaTensor) { + toDeleteTensorIdxes.emplace_back(i); + betaFind = true; + this->betaTensor = nullptr; + } + } + if (!muFind || !sigmaFind || !betaFind || !gammaFind) { + MS_LOG(ERROR) << "Can not find muTensor or sigmaTensor or betaTensor or gammaTensor in graph"; + return RET_ERROR; + } + auto status = RemoveTensor(graph, toDeleteTensorIdxes); + if (status != RET_OK) { + MS_LOG(ERROR) << "Remove ConstTensors failed" << bnFold->name.c_str(); + return RET_ERROR; + } + return RET_OK; +} + +BatchNormFoldFusionPass::~BatchNormFoldFusionPass() { + if (newWeightTensor == nullptr) { + newWeightTensor.reset(); + newWeightTensor = nullptr; + } + if (newBiasTensor == nullptr) { + newBiasTensor.reset(); + newBiasTensor = nullptr; + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h new file mode 100644 index 00000000000..fea9f8a2586 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/batchnorm_fold_fusion_pass.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H +#define MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +// input = input +// weight = SimQuantPerChannel(weight * gamma / sigma) +// bias = beta - gamma * mi / sigma +// MulFold: gamma sigma +// BatchNormFold: mi sigma +// AddFold: gamma beta mi sigma +class BatchNormFoldFusionPass : public FusionPass { + public: + BatchNormFoldFusionPass() = default; + + ~BatchNormFoldFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(MetaGraphT *graph) override; + + protected: + STATUS FindNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); + STATUS CheckPath(MetaGraphT *graph, const std::unordered_map> &matchedPath); + STATUS FindTensors(); + STATUS GenNewWeightTensor(); + STATUS GenNewBiasTensor(); + STATUS IsolateNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); + void UpdateConvWeights(); + STATUS DeleteConstTensors(); + + protected: + MetaGraphT *graph = nullptr; + CNodeT *preConv = nullptr; + CNodeT *bnFold = nullptr; + CNodeT *mulFold = nullptr; + CNodeT *fakeNode = nullptr; + CNodeT *convNode = nullptr; + CNodeT *addFold = nullptr; + TensorT *muTensor = nullptr; + TensorT *sigmaTensor = nullptr; + TensorT *gammaTensor = nullptr; + TensorT *betaTensor = nullptr; + TensorT *oldWeightTensor = nullptr; + int32_t channelOut = 0; + + std::unique_ptr newWeightTensor = nullptr; + std::unique_ptr newBiasTensor = nullptr; + + std::string inputOpName = "Input"; + std::string convPatternOpName1 = "Convolution1"; + std::string bnFoldOpName = "BatchNormFold"; + std::string mulFoldOpName = "MulFold"; + std::string fakeQuantOpName = "FakeQuant"; + std::string convPatternOpName2 = "Convolution2"; + std::string addFoldOpName = "AddFold"; + std::string withPrePatternName = "BNFoldFusionWithPre"; + std::string noPrePatternName = "BNFoldFusionNoPre"; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.cc new file mode 100644 index 00000000000..085d831d71d --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "tools/common/graph_util.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define CONV_ACTIVATION_MATCH_PATH_LEN 2 + +STATUS ConvActivationFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto actOp = std::make_shared(); + actOp->id = ACTIVATION_NAME; + actOp->types = {schema::PrimitiveType_Activation}; + actOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvActivationFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(actOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +// 1. change attr of conv +// 2. delete Activation node +STATUS ConvActivationFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_ACTIVATION_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-Activation-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + auto actPath = matchedPath[ACTIVATION_NAME]; + auto &convNode = graph->nodes.at(convPath->nodeIdx); + auto &actNode = graph->nodes.at(actPath->nodeIdx); + + // todo if combine conv_relu_fusion and conv_relu6_fusion to conv_activation_fusion + if (actNode->primitive->value.AsActivation()->type != this->activationType) { + return RET_NO_CHANGE; + } + + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->activationType = this->activationType; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->activationType = this->activationType; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // remove activation node + MergeNodeAttrFromPost(convNode, actNode); + auto status = IsolateOneWayNode(graph, actPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << actPath->subGraphIdx << ", node: " << actPath->nodeIdx + << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS ConvActivationFusionPass::Run(schema::MetaGraphT *graph) { + SetActivationType(); + return FusionPass::Run(graph); +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.h new file mode 100644 index 00000000000..748fd42c3b6 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_activation_fusion_pass.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvActivationFusionPass : public FusionPass { + public: + ConvActivationFusionPass() = default; + + ~ConvActivationFusionPass() override = default; + + STATUS DefinePattern() override; + + virtual STATUS SetActivationType() = 0; + + // 1. change attr of conv + // 2. delete Activation node + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + schema::ActivationType activationType = schema::ActivationType_RELU; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc new file mode 100644 index 00000000000..c353f2ca94e --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.cc @@ -0,0 +1,288 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h" +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define CONV_BIASADD_MATCH_PATH_LEN 2 +#define BIASADD_OP_BIAS_INDEX_IN_WEIGHT 0 +#define BIASADD_OP_INPUT_NUM 2 +#define BIASADD_OP_CONST_TENSOR_INDEX 1 + +STATUS ConvBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS ConvBiasAddFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeConv2D}; + auto baOp = std::make_shared(); + baOp->id = BIASADD_NAME; + baOp->types = {schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Add}; + baOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvBiasAddFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(baOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_BIASADD_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-BiasAdd-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + auto baPath = matchedPath[BIASADD_NAME]; + auto &convNode = graph->nodes.at(convPath->nodeIdx); + auto &baNode = graph->nodes.at(baPath->nodeIdx); + // add/biasadd node the second tensor is not constant tensor, don't fusion + auto baNodeInputIndex = baNode->inputIndex; + if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get(); + MS_ASSERT(baNodeBiasTensor != nullptr); + if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { + // dont fusion, return + return RET_OK; + } + + // 1. generate newBiasTensor for conv + auto status = GenConvBiasTensor(convPath, baPath, graph); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvBiasTensor failed, " << status; + return status; + } + if (this->newBiasTensor != nullptr) { + status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "AddTensor2Node failed, node: " << convPath->nodeIdx << ", error: " << status; + return status; + } + // add bias quantParam + // todo add quantParam for tensors + + // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { + // std::unique_ptr quantParamArray(new QuantParamArrayT()); + // if (quantParamArray == nullptr) { + // MS_LOG(ERROR) << "new QuantParamArrayT failed"); + // return RET_ERROR; + // } + // std::unique_ptr quantParam(new QuantParamT()); + // if (quantParam == nullptr) { + // MS_LOG(ERROR) << "new QuantParamT failed"); + // return RET_ERROR; + // } + // quantParam->numBits = -1; + // quantParam->scale = FLT_MAX; + // quantParam->zeroPoint = 0; + // quantParam->narrowRange = true; + // quantParam->min = FLT_MAX; + // quantParam->max = FLT_MAX; + // quantParamArray->param.emplace_back(quantParam.release()); + // convNode->quantParam.emplace_back(quantParamArray.release()); + // } + } + + // 2. change attr of conv + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { + convNode->primitive->value.AsDeConv2D()->hasBias = true; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // 5. delete BiasAdd node + MergeNodeAttrFromPost(convNode, baNode); + status = IsolateOneWayNode(graph, baPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, graph: %zu, node: %zu, error: %d"; + //, baPath->subGraphIdx, baPath->nodeIdx, status); + return status; + } + + return RET_OK; +} + +#define BIASADD_WEIGHT_SHAPE_SIZE 1 +#define BIASADD_BIAS_DIM_INDEX 0 + +STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr convPath, std::shared_ptr baPath, + MetaGraphT *graph) { + MS_ASSERT(convPath != nullptr); + MS_ASSERT(baPath != nullptr); + MS_ASSERT(graph != nullptr); + + auto convNode = graph->nodes.at(convPath->nodeIdx).get(); + MS_ASSERT(convNode != nullptr); + auto baNode = graph->nodes.at(baPath->nodeIdx).get(); + MS_ASSERT(baNode != nullptr); + int32_t kernelNum = 0; + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + kernelNum = convNode->primitive->value.AsConv2D()->channelOut; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelIn * + convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) { + kernelNum = convNode->primitive->value.AsDeConv2D()->channelOut; + } + auto convWeightTensorIdxes = convNode->inputIndex; + if (convWeightTensorIdxes.size() < CONV_OP_NO_BIAS_INPUT_NUM) { + MS_LOG(ERROR) << convNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); + auto baWeightTensorIdxes = baNode->inputIndex; + if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! "; + return RET_ERROR; + } + baWeightTensorIdxes.erase(baWeightTensorIdxes.begin()); + + if (convWeightTensorIdxes.empty()) { + MS_LOG(ERROR) << "Conv2D should has one weight tensors at least, current number of weight tensors " + << convWeightTensorIdxes.size(); + return RET_ERROR; + } + + if (baWeightTensorIdxes.empty()) { + MS_LOG(ERROR) << "BiasAdd should has one weight tensors at least, current number of weight tensors " + << baWeightTensorIdxes.size(); + return RET_ERROR; + } + + TensorT *oldBiasTensor = nullptr; + TensorT *biasTensor = nullptr; + + if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { + oldBiasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); + MS_ASSERT(oldBiasTensor != nullptr); + } + biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX_IN_WEIGHT)).get(); + MS_ASSERT(biasTensor != nullptr); + auto biasDims = biasTensor->dims; + // if biasTensor is a scaler + if (biasDims.empty() && biasTensor->data.data() == nullptr) { + MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid" << baNode->name.c_str(); + return RET_ERROR; + } + if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) { + MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size() + << ". or bias tensor is a scaler"; + return RET_ERROR; + } + if (!biasDims.empty() && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) { + MS_LOG(ERROR) << "Size(%d) of BiasAdd(%s) bias tensor should be equal to kernelNum(%d)" + << biasDims.at(BIASADD_BIAS_DIM_INDEX) << baNode->name.c_str() << kernelNum; + return RET_ERROR; + } + + // cal new biasData + this->newBiasData = new (std::nothrow) float[kernelNum]; + if (newBiasData == nullptr) { + MS_LOG(ERROR) << "new newBiasData failed"; + return RET_ERROR; + } + + if (biasDims.empty() && biasTensor->data.data() != nullptr) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + if (0 != memset_s(newBiasData, kernelNum * sizeof(float), *biasData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset_s newBiasData failed"; + return RET_ERROR; + } + } else { + if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s newBiasData failed"; + return RET_ERROR; + } + } + if (oldBiasTensor != nullptr) { + auto oldBiasDims = oldBiasTensor->dims; + if (oldBiasDims.size() != 1) { + MS_LOG(ERROR) + << "Conv bias tensor should has one dimension, current number of dimension %zu"; // oldBiasDims.size()); + return RET_ERROR; + } + if (oldBiasDims.at(0) != kernelNum) { + MS_LOG(ERROR) + << "Size(%zu) of Conv bias tensor should be equal to kernelNum(%d), current number of dimension %zu"; + // oldBiasDims.size(), kernelNum); + return RET_ERROR; + } + auto *oldBiasData = reinterpret_cast(oldBiasTensor->data.data()); + for (size_t i = 0; i < kernelNum; i++) { + oldBiasData[i] += newBiasData[i]; + } + } else { + auto *newCharBiasData = reinterpret_cast(newBiasData); + std::vector tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); + + auto weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + this->newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + // todo biasShape + this->newBiasTensor->dims = {kernelNum}; + this->newBiasTensor->dataType = weightTensor->dataType; + this->newBiasTensor->format = weightTensor->format; + this->newBiasTensor->refCount = weightTensor->refCount; + this->newBiasTensor->data.swap(tmpBiasVec); + newCharBiasData = nullptr; + } + + delete (this->newBiasData); + newBiasData = nullptr; + + return RET_OK; +} + +ConvBiasAddFusionPass::~ConvBiasAddFusionPass() { + if (this->newBiasData != nullptr) { + delete (this->newBiasData); + } +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h new file mode 100644 index 00000000000..41426bfcc94 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_biasadd_fusion_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvBiasAddFusionPass : public FusionPass { + public: + ConvBiasAddFusionPass() = default; + + ~ConvBiasAddFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + // gen this->newBiasTensor if conv has no bias before + STATUS GenConvBiasTensor(std::shared_ptr convPath, std::shared_ptr dstPath, schema::MetaGraphT *graph); + + protected: + float *newBiasData = nullptr; + std::unique_ptr newBiasTensor = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc new file mode 100644 index 00000000000..f8e20ef1e19 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "tools/converter/optimizer/fusion/conv_bn_fusion_pass.h" +#include "securec/include/securec.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +#define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2 +#define TF_BATCHNORM_OP_WEIGHT_NUM 4 +#define CAFFE_BATCHNORM_MEAN_INDEX 0 +#define CAFFE_BATCHNORM_VARIANCE_INDEX 1 +#define TF_BATCHNORM_SCALE_INDEX 0 +#define TF_BATCHNORM_BIAS_INDEX 1 +#define TF_BATCHNORM_MEAN_INDEX 2 +#define TF_BATCHNORM_VARIANCE_INDEX 3 + +constexpr const float EPS = 1e-8; +constexpr const float EPS_DEFAULT_FLOAT = 1e-5; +constexpr const float POW_NUM = 0.5; + +STATUS ConvBNFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvBNFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto bnOp = std::make_shared(); + bnOp->id = DST_NAME; + bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_CaffeBatchNorm}; + bnOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(bnOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvBNFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } + +STATUS ConvBNFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnPath != nullptr); + + BNWeightTensors bnWeightTensors; + + auto status = GetBnWeightTensors(graph, bnPath, kernelNum, bnWeightTensors); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnWeightTensors error " << status; + return status; + } + schema::TensorT *meanTensor = bnWeightTensors.meanTensor; + schema::TensorT *varianceTensor = bnWeightTensors.varianceTensor; + schema::TensorT *scaleTensor = bnWeightTensors.scaleTensor; + schema::TensorT *biasTensor = bnWeightTensors.biasTensor; + + auto *meanData = reinterpret_cast(meanTensor->data.data()); + auto *varianceData = reinterpret_cast(varianceTensor->data.data()); + + float eps = EPS_DEFAULT_FLOAT; + status = GetBnEpsilon(graph, bnPath, eps); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnEpsilon failed " << status; + return status; + } + + // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) + if (memcpy_s(transScale, kernelNum * sizeof(float), varianceData, kernelNum * sizeof(float)) != 0) { + MS_LOG(ERROR) << "memcpy_s transScale error"; + return RET_ERROR; + } + // 1/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + float tmp = transScale[i] + eps; + tmp = pow(tmp, POW_NUM); + transScale[i] = 1 / tmp; + } + + if (scaleTensor != nullptr) { + auto *scaleData = reinterpret_cast(scaleTensor->data.data()); + // scale/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + transScale[i] *= scaleData[i]; + } + } + + // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) + // -mean/sqrt(variance + eps) + for (int32_t i = 0; i < kernelNum; i++) { + transBias[i] = -meanData[i] * transScale[i]; + } + + if (biasTensor != nullptr) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + // -scale*mean/sqrt(variance + eps) + bias + for (int32_t i = 0; i < kernelNum; i++) { + transBias[i] += biasData[i]; + } + } + + return RET_OK; +} + +// BatchNorm weight Tensor definition: +// caffe +// estimated_mean --0 +// estimated_variance --1 +// tensorflow +// scale -- 0 +// bias --1 +// estimated_mean --2 +// estimated_variance --3 +STATUS ConvBNFusionPass::GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum, + BNWeightTensors &bnWeightTensors) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnPath != nullptr); + auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); + auto bnWeightTensorIdxes = bnNode->inputIndex; + bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin()); + if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) { + bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get(); + bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get(); + } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) { + bnWeightTensors.scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get(); + bnWeightTensors.biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get(); + bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get(); + bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get(); + } else { + MS_LOG(ERROR) << "BatchNorm should has " << CAFFE_BATCHNORM_OP_WEIGHT_NUM << " or " << TF_BATCHNORM_OP_WEIGHT_NUM + << " weight tensors, current number of weight tensors " << bnWeightTensorIdxes.size(); + return RET_ERROR; + } + + if (bnWeightTensors.meanTensor == nullptr) { + MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr"; + return RET_ERROR; + } + + if (bnWeightTensors.varianceTensor == nullptr) { + MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr"; + return RET_ERROR; + } + + if (kernelNum != bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + + if (kernelNum != bnWeightTensors.varianceTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + + if (bnWeightTensors.scaleTensor != nullptr) { + if (kernelNum != bnWeightTensors.scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + } + + if (bnWeightTensors.biasTensor != nullptr) { + if (kernelNum != bnWeightTensors.biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to mean size(" + << bnWeightTensors.meanTensor->data.size() * sizeof(uint8_t) / sizeof(float) << ")"; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr bnPath, float &eps) { + MS_ASSERT(graph != nullptr); + auto bnNode = graph->nodes.at(bnPath->nodeIdx).get(); + MS_ASSERT(bnNode != nullptr); + if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) { + eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon; + } else if (bnNode->primitive->value.type == schema::PrimitiveType_CaffeBatchNorm) { + eps = bnNode->primitive->value.AsCaffeBatchNorm()->epsilon; + } else { + MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node"; + return RET_ERROR; + } + + if (eps < EPS) { + eps = EPS_DEFAULT_FLOAT; + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.h new file mode 100644 index 00000000000..9e541ff82b7 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_bn_fusion_pass.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#ifndef MINDSPORE_CONV_BN_FUSION_PASS_H +#define MINDSPORE_CONV_BN_FUSION_PASS_H + +#include "tools/converter/optimizer/fusion/conv_bn_fusion_pass.h" +#include "tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h" + +namespace mindspore { +namespace lite { +class ConvBNFusionPass : public ConvScaleBiasFusionPass { + public: + ConvBNFusionPass() = default; + + ~ConvBNFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum) override; + + // Get and check BNNode weight tensor + STATUS GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr bnPath, int32_t kernelNum, + BNWeightTensors &bnWeightTensors); + + STATUS GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr bnPath, float &eps); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CONV_BN_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.cc new file mode 100644 index 00000000000..6b53bdc7a2e --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS ConvRelu6FusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); } + +STATUS ConvRelu6FusionPass::SetActivationType() { + this->activationType = ActivationType_RELU6; + return RET_OK; +} + +STATUS ConvRelu6FusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvRelu6FusionPass::Run(MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); } + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h new file mode 100644 index 00000000000..ada49e8b8dd --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu6_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H + +#include "tools/converter/optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvRelu6FusionPass : public ConvActivationFusionPass { + public: + ConvRelu6FusionPass() = default; + + ~ConvRelu6FusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS SetActivationType() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.cc new file mode 100644 index 00000000000..2e0c45a1ed6 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/fusion/conv_relu_fusion_pass.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS ConvReluFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvReluFusionPass::Run(schema::MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); } + +STATUS ConvReluFusionPass::SetActivationType() { + this->activationType = schema::ActivationType_RELU; + return RET_OK; +} + +STATUS ConvReluFusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.h new file mode 100644 index 00000000000..3af33940896 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_relu_fusion_pass.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H + +#include "tools/converter/optimizer/fusion/conv_activation_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvReluFusionPass : public ConvActivationFusionPass { + public: + ConvReluFusionPass() = default; + + ~ConvReluFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS SetActivationType() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.cc new file mode 100644 index 00000000000..5ed5e555765 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.cc @@ -0,0 +1,361 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. + * Description: mslite + * Author: mslite + * Create: 2019-12-13 + */ + +#include +#include +#include +#include +#include +#include +#include "tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { + +#define CONV_SCALE_BIAS_MATCH_PATH_LEN 2 + +// 1. generate biasTensor according to BN weightTensor +// 2. change attr of conv +// 3. delete BN node +STATUS ConvScaleBiasFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != CONV_SCALE_BIAS_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "Conv-Scale-Bias-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto convPath = matchedPath[kConvName]; + MS_ASSERT(convPath != nullptr); + auto dstPath = matchedPath[DST_NAME]; + MS_ASSERT(dstPath != nullptr); + MS_ASSERT(subGraph != nullptr); + auto &convNode = graph->nodes.at(convPath->nodeIdx); + MS_ASSERT(convNode != nullptr); + auto &dstNode = graph->nodes.at(dstPath->nodeIdx); + MS_ASSERT(dstNode != nullptr); + + // 1. generate new weightTensor and biasTensor for conv + auto status = GenConvWeightTensors(graph, convPath, dstPath); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; + return status; + } + if (convNode->inputIndex.size() == CONV_OP_HAS_BIAS_INPUT_NUM) { + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), + std::move(this->newWeightTensor)); + this->newWeightTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT), + std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + } else if (convNode->inputIndex.size() == CONV_OP_NO_BIAS_INPUT_NUM) { + status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT), + std::move(this->newWeightTensor)); + this->newWeightTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor)); + this->newBiasTensor = nullptr; + if (status != RET_OK) { + MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx + << ", node: " << convPath->nodeIdx << ", tensor " + << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status; + return status; + } + // if (convNode->name == "Conv_461") { + // } + // add bias quantParam + // todo use tensor quant param + // if (convNode->quantParam.size() == convNode->inputIndex.size() + convNode->outputIndex.size() - 1) { + // std::unique_ptr quantParamArray(new QuantParamArrayT()); + // if (quantParamArray == nullptr) { + // MS_LOG(ERROR) << "new QuantParamArrayT failed"; + // return RET_ERROR; + // } + // std::unique_ptr quantParam(new QuantParamT()); + // if (quantParam == nullptr) { + // MS_LOG(ERROR) << "new QuantParamT failed"; + // return RET_ERROR; + // } + // quantParam->numBits = -1; + // quantParam->scale = FLT_MAX; + // quantParam->zeroPoint = 0; + // quantParam->narrowRange = true; + // quantParam->min = FLT_MAX; + // quantParam->max = FLT_MAX; + // quantParamArray->param.emplace_back(quantParam.release()); + // convNode->quantParam.emplace_back(quantParamArray.release()); + // } + } else { + MS_LOG(ERROR) << "Conv node should has 2 or 3 weight tensors rather than " << convNode->inputIndex.size(); + return RET_ERROR; + } + + // 2. change attr of conv + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + convNode->primitive->value.AsConv2D()->hasBias = true; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + + // 3. delete DST node + MergeNodeAttrFromPost(convNode, dstNode); + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstPath->nodeIdx << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + std::shared_ptr dstPath) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convPath != nullptr); + MS_ASSERT(dstPath != nullptr); + MS_ASSERT(subGraph != nullptr); + auto &convNode = graph->nodes.at(convPath->nodeIdx); + MS_ASSERT(convNode != nullptr); + int32_t kernelNum = -1; + if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { + kernelNum = convNode->primitive->value.AsConv2D()->channelOut; + } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { + kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier * + convNode->primitive->value.AsDepthwiseConv2D()->channelIn; + } else { + MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type; + return RET_ERROR; + } + if (kernelNum <= 0) { + MS_LOG(ERROR) << "KernelNum should be positive, " << kernelNum; + return RET_ERROR; + } + + this->transScale = new (std::nothrow) float[kernelNum]; + this->transBias = new (std::nothrow) float[kernelNum]; + + if (transScale == nullptr) { + MS_LOG(ERROR) << "new transScale failed"; + return RET_ERROR; + } + + if (transBias == nullptr) { + MS_LOG(ERROR) << "new transBias failed"; + return RET_ERROR; + } + + if (0 != memset_s(transScale, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset transScale failed"; + return RET_ERROR; + } + + if (0 != memset_s(transBias, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset transBias failed"; + return RET_ERROR; + } + + auto status = GetTransParam(graph, dstPath, kernelNum); + if (RET_OK != status) { + MS_LOG(ERROR) << "GetTransParam failed, " << status; + return status; + } + + status = CalConvWeightTensors(graph, convPath, kernelNum); + if (RET_OK != status) { + MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status; + return status; + } + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalNewWeightTensor(TensorT *oldWeightTensor, const int32_t kernelNum, + const size_t kernelSize) { + MS_ASSERT(oldWeightTensor != nullptr); + auto weightData = reinterpret_cast(oldWeightTensor->data.data()); + size_t kernelDataCount = kernelNum * kernelSize; + if (kernelDataCount == 0) { + MS_LOG(ERROR) << "KernelDataCount should be positive, " << kernelDataCount; + return RET_ERROR; + } + this->newWeightData = new (std::nothrow) float[kernelDataCount]; + if (newWeightData == nullptr) { + MS_LOG(ERROR) << "new newWeightData failed"; + return RET_ERROR; + } + + if (0 != memset_s(newWeightData, kernelDataCount * sizeof(float), 0, kernelDataCount * sizeof(float))) { + MS_LOG(ERROR) << "memset newWeightData failed"; + return RET_ERROR; + } + + for (size_t i = 0; i < kernelNum; i++) { + for (size_t j = 0; j < kernelSize; j++) { + newWeightData[i * kernelSize + j] = weightData[i * kernelSize + j] * transScale[i]; + } + } + auto newCharWeightData = reinterpret_cast(newWeightData); + std::vector tmpWeightVec(newCharWeightData, + newCharWeightData + kernelDataCount * sizeof(float) / sizeof(uint8_t)); + + this->newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (this->newWeightTensor == nullptr) { + MS_LOG(ERROR) << "new newWeightTensor failed"; + return RET_ERROR; + } + this->newWeightTensor->dims.insert(this->newWeightTensor->dims.begin(), oldWeightTensor->dims.begin(), + oldWeightTensor->dims.end()); + this->newWeightTensor->dataType = oldWeightTensor->dataType; + this->newWeightTensor->format = oldWeightTensor->format; + this->newWeightTensor->refCount = oldWeightTensor->refCount; + this->newWeightTensor->data.swap(tmpWeightVec); + delete (this->newWeightData); + newWeightData = nullptr; + + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalNewBiasTensor(TensorT *oldWeightTensor, TensorT *oldBiasTensor, + const int32_t kernelNum) { + MS_ASSERT(oldWeightTensor != nullptr); + this->newBiasData = new (std::nothrow) float[kernelNum]; + if (newBiasData == nullptr) { + MS_LOG(ERROR) << "new newBiasData failed"; + return RET_ERROR; + } + if (0 != memset_s(newBiasData, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memset newBiasData failed"; + return RET_ERROR; + } + + if (oldBiasTensor != nullptr) { + auto *biasData = reinterpret_cast(oldBiasTensor->data.data()); + + for (size_t i = 0; i < kernelNum; i++) { + this->newBiasData[i] = biasData[i] * transScale[i] + transBias[i]; + } + } else { + if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), transBias, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s newBiasData failed"; + return RET_ERROR; + } + } + auto *newCharBiasData = reinterpret_cast(newBiasData); + std::vector tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t)); + + this->newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); + if (this->newBiasTensor == nullptr) { + MS_LOG(ERROR) << "new newBiasTensor failed"; + return RET_ERROR; + } + // todo biasShape + this->newBiasTensor->dims = {kernelNum}; + this->newBiasTensor->dataType = oldWeightTensor->dataType; + this->newBiasTensor->format = oldWeightTensor->format; + this->newBiasTensor->refCount = oldWeightTensor->refCount; + this->newBiasTensor->data.swap(tmpBiasVec); + delete (this->newBiasData); + newCharBiasData = nullptr; + newBiasData = nullptr; + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(convPath != nullptr); + + auto convNode = graph->nodes.at(convPath->nodeIdx).get(); + MS_ASSERT(convNode != nullptr); + auto convWeightTensorIdxes = convNode->inputIndex; + convWeightTensorIdxes.erase(convWeightTensorIdxes.begin()); + + TensorT *weightTensor = nullptr; + TensorT *biasTensor = nullptr; + if (convWeightTensorIdxes.size() == CONV_OP_NO_BIAS_WEIGHT_NUM) { + weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + } else if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) { + weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get(); + biasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get(); + } else { + MS_LOG(ERROR) << "Conv2D should has " << CONV_OP_NO_BIAS_WEIGHT_NUM << " or " << CONV_OP_HAS_BIAS_WEIGHT_NUM + << " weight tensors, current number of weight tensors " << convWeightTensorIdxes.size(); + return RET_ERROR; + } + if (weightTensor == nullptr) { + MS_LOG(ERROR) << "Conv2D's weight tensor is nullptr"; + return RET_ERROR; + } + + auto weightShape = weightTensor->dims; + if (weightShape.size() != CONV_FILTER_SHAPE_SIZE) { + MS_LOG(ERROR) << "Size of dims of weight tensor should be " << CONV_FILTER_SHAPE_SIZE << " rather than " + << weightShape.size(); + return RET_ERROR; + } + size_t kernelSize = GetShapeSize(*weightTensor) / kernelNum; + + // cal new weightData + auto status = CalNewWeightTensor(weightTensor, kernelNum, kernelSize); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalNewWeightTensor error " << status; + return status; + } + // cal new biasData + status = CalNewBiasTensor(weightTensor, biasTensor, kernelNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalNewBiasTensor error " << status; + return status; + } + return RET_OK; +} + +STATUS ConvScaleBiasFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } + +ConvScaleBiasFusionPass::~ConvScaleBiasFusionPass() { + if (this->transScale != nullptr) { + delete (this->transScale); + } + if (this->transBias != nullptr) { + delete (this->transBias); + } + if (this->newWeightData != nullptr) { + delete (this->newWeightData); + } + if (this->newBiasData != nullptr) { + delete (this->newBiasData); + } +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h new file mode 100644 index 00000000000..325fad6a0b9 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved. + * Description: mslite + * Author: mslite + * Create: 2019-12-13 + */ + +#ifndef MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +struct BNWeightTensors { + schema::TensorT *meanTensor = nullptr; + schema::TensorT *varianceTensor = nullptr; + schema::TensorT *scaleTensor = nullptr; + schema::TensorT *biasTensor = nullptr; +}; + +class ConvScaleBiasFusionPass : public FusionPass { + public: + ConvScaleBiasFusionPass() = default; + + ~ConvScaleBiasFusionPass() override; + + STATUS DefinePattern() override = 0; + + // 1. generate biasTensor according to BN weightTensor + // 2. change attr of conv + // 3. delete BN node + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + protected: + // call GetTransParam() and CalConvWeightTensors() + STATUS GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, + std::shared_ptr dstPath); + + // fill this->transScale and this->transBias + virtual STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr dstPath, int32_t kernelNum) = 0; + + // fill this->newWeightTensor and this->newBiasTensor according to this->transScale and this->transBias + STATUS CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr &convPath, int32_t kernelNum); + + STATUS CalNewWeightTensor(schema::TensorT *oldWeightTensor, int32_t kernelNum, size_t kernelSize); + + STATUS CalNewBiasTensor(schema::TensorT *oldWeightTensor, schema::TensorT *oldBiasTensor, int32_t kernelNum); + + protected: + float *transScale = nullptr; + float *transBias = nullptr; + float *newWeightData = nullptr; + float *newBiasData = nullptr; + std::unique_ptr newWeightTensor = nullptr; + std::unique_ptr newBiasTensor = nullptr; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.cc new file mode 100644 index 00000000000..fb02626861c --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.cc @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/fusion/conv_scale_fusion_pass.h" +#include "securec/include/securec.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define SCALE_OP_NO_BIAS_WEIGHT_NUM 1 +#define SCALE_OP_HAS_BIAS_WEIGHT_NUM 2 + +#define SCALE_OP_SCALE_INDEX_IN_WEIGHT 0 +#define SCALE_OP_BIAS_INDEX_IN_WEIGHT 1 + +STATUS ConvScaleFusionPass::DefinePattern() { + auto convOp = std::make_shared(); + convOp->id = kConvName; + convOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; + auto scaleOp = std::make_shared(); + scaleOp->id = DST_NAME; + scaleOp->types = {schema::PrimitiveType_Scale}; + scaleOp->left = convOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("ConvScaleFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(convOp); + fusionPattern->AddPatternOp(scaleOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS ConvScaleFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath); +} + +STATUS ConvScaleFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); } + +STATUS ConvScaleFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, + int32_t kernelNum) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(scalePath != nullptr); + + auto scaleNode = graph->nodes.at(scalePath->nodeIdx).get(); + MS_ASSERT(scaleNode != nullptr); + auto scaleWeightTensorIdxes = scaleNode->inputIndex; + scaleWeightTensorIdxes.erase(scaleWeightTensorIdxes.begin()); + + schema::TensorT *scaleTensor = nullptr; + schema::TensorT *biasTensor = nullptr; + + if (scaleWeightTensorIdxes.size() == SCALE_OP_NO_BIAS_WEIGHT_NUM) { + scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); + } else if (scaleWeightTensorIdxes.size() == SCALE_OP_HAS_BIAS_WEIGHT_NUM) { + scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get(); + biasTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_BIAS_INDEX_IN_WEIGHT]).get(); + } else { + MS_LOG(ERROR) << "Scale should has %d or %d weight tensors, current number of weight tensors %zu"; + // SCALE_OP_NO_BIAS_WEIGHT_NUM, SCALE_OP_HAS_BIAS_WEIGHT_NUM, scaleWeightTensorIdxes.size()); + return RET_ERROR; + } + + if (scaleTensor == nullptr) { + MS_LOG(ERROR) << "Scale's scale tensor is nullptr"; + return RET_ERROR; + } + + if (kernelNum != scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to scale size(%lu)"; + //, kernelNum, scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)); + return RET_ERROR; + } + + const float *scaleData = reinterpret_cast(scaleTensor->data.data()); + + if (0 != memcpy_s(transScale, kernelNum * sizeof(float), scaleData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s transScale failed"; + return RET_ERROR; + } + + if (biasTensor != nullptr) { + if (kernelNum != biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) { + MS_LOG(ERROR) << "conv kernel num %u is expected to be equal to bias size(%lu)"; + //, kernelNum, biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)); + return RET_ERROR; + } + + const float *biasData = reinterpret_cast(biasTensor->data.data()); + + if (0 != memcpy_s(transBias, kernelNum * sizeof(float), biasData, kernelNum * sizeof(float))) { + MS_LOG(ERROR) << "memcpy_s transBias failed"; + return RET_ERROR; + } + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.h new file mode 100644 index 00000000000..8c73f01fe2d --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/conv_scale_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H +#define MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H + +#include "tools/converter/optimizer/fusion/conv_scale_bias_fusion_pass.h" +#include +#include +#include + +namespace mindspore { +namespace lite { +class ConvScaleFusionPass : public ConvScaleBiasFusionPass { + public: + ConvScaleFusionPass() = default; + + ~ConvScaleFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; + + private: + STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr scalePath, int32_t kernelNum) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.cc new file mode 100644 index 00000000000..9fb0b75aa42 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/fusion/format_trans_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define kFormatTransMatchPathLen2 2 +#define kFormatTransMatchPathLen3 3 + +STATUS FormatTransFusionPass::DefinePattern() { + // nchw2nhwc + nhwc2nchw + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + nh2ncOp->left = nc2nhOp; + std::unique_ptr nc2NhAndNh2NcFusionPattern(new (std::nothrow) + FusionPattern(kNc2NhAndNh2NcFusionPattern)); + if (nc2NhAndNh2NcFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNc2NhAndNh2NcFusionPattern); + return RET_ERROR; + } + nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp); + nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp); + nc2NhAndNh2NcFusionPattern->Finish(); + this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release()); + } + // nchw2nhwc + QuantDtypeCast + nhwc2nchw + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto passOp = std::make_shared(); + passOp->id = kFormatTransPassOp; + passOp->types = {PrimitiveType_QuantDTypeCast}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + passOp->left = nc2nhOp; + nh2ncOp->left = passOp; + std::unique_ptr nc2NhAndNh2NcPassFusionPattern(new FusionPattern(kNc2NhAndNh2NcPassFusionPattern)); + if (nc2NhAndNh2NcPassFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNc2NhAndNh2NcPassFusionPattern); + return RET_ERROR; + } + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp); + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp); + nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp); + nc2NhAndNh2NcPassFusionPattern->Finish(); + this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release()); + } + // nhwc2nchw + nchw2nhwc + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + nc2nhOp->left = nh2ncOp; + std::unique_ptr nh2NcAndNc2NhFusionPattern(new (std::nothrow) + FusionPattern(kNh2NcAndNc2NhFusionPattern)); + if (nh2NcAndNc2NhFusionPattern == nullptr) { + // MS_LOG(ERROR) << "new %s failed", kNh2NcAndNc2NhFusionPattern); + return RET_ERROR; + } + nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp); + nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp); + nh2NcAndNc2NhFusionPattern->Finish(); + this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release()); + } + // nhwc2nchw + QuantDtypeCast + nchw2nhwc + { + auto nc2nhOp = std::make_shared(); + nc2nhOp->id = kFormatTransNc2NhOp; + nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; + auto passOp = std::make_shared(); + passOp->id = kFormatTransPassOp; + passOp->types = {PrimitiveType_QuantDTypeCast}; + auto nh2ncOp = std::make_shared(); + nh2ncOp->id = kFormatTransNh2NcOp; + nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; + + passOp->left = nh2ncOp; + nc2nhOp->left = passOp; + std::unique_ptr nh2NcAndNc2NhPassFusionPattern(new (std::nothrow) + FusionPattern(kNh2NcAndNc2NhPassFusionPattern)); + if (nh2NcAndNc2NhPassFusionPattern == nullptr) { + MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed"; + return RET_ERROR; + } + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp); + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp); + nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp); + nh2NcAndNc2NhPassFusionPattern->Finish(); + this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release()); + } + return RET_OK; +} + +STATUS FormatTransFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != kFormatTransMatchPathLen2 && matchedPath.size() != kFormatTransMatchPathLen3) { + MS_LOG(ERROR) << "Format-Transform-Fusion should have " << kFormatTransMatchPathLen2 << " or " + << kFormatTransMatchPathLen3 << " NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + std::shared_ptr srcPath; + std::shared_ptr dstPath; + if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { + srcPath = matchedPath[kFormatTransNc2NhOp]; + dstPath = matchedPath[kFormatTransNh2NcOp]; + } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { + srcPath = matchedPath[kFormatTransNh2NcOp]; + dstPath = matchedPath[kFormatTransNc2NhOp]; + } else { + MS_ASSERT(false); + } + MS_ASSERT(srcPath != nullptr); + MS_ASSERT(dstPath != nullptr); + auto srcNode = graph->nodes.at(srcPath->nodeIdx).get(); + auto dstNode = graph->nodes.at(dstPath->nodeIdx).get(); + MS_ASSERT(srcNode != nullptr); + MS_ASSERT(dstNode != nullptr); + if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { + MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc); + MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw); + } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { + MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw); + MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc); + } else { + MS_ASSERT(false); + } + + auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; + return status; + } + + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; + return status; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.h new file mode 100644 index 00000000000..eaab52a4fdc --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/format_trans_fusion_pass.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H +#define MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp"; +constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp"; +constexpr const char *kFormatTransPassOp = "FormatTransPassOp"; +constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern"; +constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; +constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern"; +constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern"; + +class FormatTransFusionPass : public FusionPass { + public: + FormatTransFusionPass() = default; + + ~FormatTransFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.cc new file mode 100644 index 00000000000..e7f498577ef --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.cc @@ -0,0 +1,349 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tools/converter/optimizer/fusion/fusion_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS FusionPass::Run(schema::MetaGraphT *graph) { + auto ret = DefinePattern(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DefinePattern Error " << ret; + return ret; + } + for (auto pattern : patterns) { + if (pattern == nullptr) { + MS_LOG(ERROR) << "FusionPattern has not been set"; + return RET_PARAM_INVALID; + } + + if (!pattern->Check()) { + MS_LOG(ERROR) << "FusionPattern is invaild"; + return RET_PARAM_INVALID; + } + } + + ret = MatchPatterns(graph); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MatchPattern Error " << ret; + return ret; + } + + if (this->matchedPaths.empty()) { + return RET_NO_CHANGE; + } else { + ret = Fuse(graph); + if (ret != RET_OK && ret != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Fuse Error " << ret; + } + return ret; + } +} + +STATUS FusionPass::MatchPatterns(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + this->matchedPaths.clear(); + STATUS status; + for (auto pattern : patterns) { + status = MatchOnePattern(graph, pattern); + if (status != RET_OK) { + MS_LOG(ERROR) << "MatchOnePatternInSubGraph failed: " << status; + return status; + } + } + this->mapedMatchedPaths.clear(); + for (auto iter = matchedPaths.begin(); iter != matchedPaths.end(); iter++) { + auto patternName = iter->first; + auto patternOps = iter->second; + std::vector>> mapedPaths; + for (const auto &patternOp : patternOps) { + std::queue> opQueue; + std::unordered_map> mapedPath; + opQueue.push(patternOp); + while (!opQueue.empty()) { + auto curPatternOp = opQueue.front(); + opQueue.pop(); + MS_ASSERT(curPatternOp != nullptr); + mapedPath.insert(std::make_pair(curPatternOp->id, curPatternOp->path)); + if (curPatternOp->left != nullptr) { + opQueue.push(curPatternOp->left); + } + if (curPatternOp->right != nullptr) { + opQueue.push(curPatternOp->right); + } + } + mapedPaths.emplace_back(mapedPath); + } + this->mapedMatchedPaths.insert(std::make_pair(patternName, mapedPaths)); + } + return RET_OK; +} + +// assume that all nodes have only one output. if node has multi-outputs, +// some errors may happen +STATUS FusionPass::MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pattern) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(pattern != nullptr); + // std::vector> patternMatchPaths; + auto outputOp = pattern->GetPatternOp(pattern->GetOutput()); + if (outputOp == nullptr) { + MS_LOG(ERROR) << "Can not find the output of the pattern"; + return RET_NULL_PTR; + } + MS_ASSERT(outputOp->isTail); + if (graph->nodes.empty()) { + return RET_OK; + } + // find all matched entries + std::vector entries; + std::queue nodeQueue; + std::vector sinkIdes; + for (auto index : graph->outputIndex) { + auto subGraphOutputNodeIdxes = GetLinkedPreIdx(*graph, index); + for (auto subGraphOutputNodeIdx : subGraphOutputNodeIdxes) { + MS_ASSERT((subGraph->nodes.size() > subGraphOutputNodeIdx)); + nodeQueue.push(subGraphOutputNodeIdx); + } + } + while (!nodeQueue.empty()) { + auto nodeIdx = nodeQueue.front(); + nodeQueue.pop(); + if (IsContain(sinkIdes, nodeIdx)) { + continue; + } + MS_ASSERT(subGraph->nodes.size() > nodeIdx); + auto &node = graph->nodes.at(nodeIdx); + sinkIdes.emplace_back(nodeIdx); + + MS_ASSERT(nullptr != node->primitive); + if (IsContain(outputOp->types, node->primitive->value.type)) { + entries.emplace_back(nodeIdx); + } + auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx); + for (auto preNodeIdx : preNodeIdxes) { + MS_ASSERT((subGraph->nodes.size() > preNodeIdx)); + nodeQueue.push(preNodeIdx); + } + } + + // check each entry + std::vector> paths; + sinkIdes.clear(); + std::vector pathSinkIdes; + for (auto nodeIdx : entries) { + if (IsContain(sinkIdes, nodeIdx)) { + continue; + } + pathSinkIdes.clear(); + auto path = PatternOp::Copy(outputOp); + auto ret = MatchTree(graph, nodeIdx, path, sinkIdes, pathSinkIdes); + if (ret && CheckMatch(graph, path)) { + paths.emplace_back(path); + } + } + auto patternName = pattern->GetName(); + this->matchedPaths.insert(std::make_pair(patternName, paths)); + return RET_OK; +} + +bool FusionPass::CheckMatch(schema::MetaGraphT *graph, const std::shared_ptr& patternOp) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(patternOp != nullptr); + // find included nodes + std::queue> opQueue; + std::vector matchedNodeIdxes; + std::vector> inputNodes; + std::shared_ptr outputNode = nullptr; + opQueue.push(patternOp); + while (!opQueue.empty()) { + auto curPatternOp = opQueue.front(); + opQueue.pop(); + matchedNodeIdxes.push_back(curPatternOp->path->nodeIdx); + if (curPatternOp->isHead) { + inputNodes.emplace_back(curPatternOp); + } + if (curPatternOp->isTail) { + if (outputNode != nullptr && outputNode != curPatternOp) { + return false; + } + outputNode = curPatternOp; + } + if (curPatternOp->left != nullptr) { + opQueue.push(curPatternOp->left); + } + if (curPatternOp->right != nullptr) { + opQueue.push(curPatternOp->right); + } + } + // all post node of input node should be in path except input node is placeHold + for (const auto& inputNode : inputNodes) { + if (inputNode->isPlaceHold) { + continue; + } + auto inputNodePostNodeIdxes = GetOutputNodeIdx(*graph, inputNode->path->nodeIdx); + for (auto inputNodePostNodeIdx : inputNodePostNodeIdxes) { + if (!IsContain(matchedNodeIdxes, inputNodePostNodeIdx)) { + return false; + } + } + } + // all pre node of output node should be in path + auto outputNodePreNodeIdxes = GetInputNodeIdx(*graph, outputNode->path->nodeIdx); + for (auto outputNodePreNodeIdx : outputNodePreNodeIdxes) { + if (!IsContain(matchedNodeIdxes, outputNodePreNodeIdx)) { + return false; + } + } + return true; +} + +bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std::shared_ptr &target, + std::vector &sinkIdes, std::vector &pathSinkIdes) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(nodeIdx < subGraph->nodes.size()); + auto &scope = graph->nodes.at(nodeIdx); + MS_ASSERT(scope != nullptr); + // if target(except target is marked head) is nullptr, it means the preNode + // has no left or right, but scope is not nullptr + if (target == nullptr) { + return false; + } + // if node is sinked and not in the pathSinkId, then return false + if (IsContain(sinkIdes, nodeIdx) && !IsContain(pathSinkIdes, nodeIdx)) { + return false; + } + // type not match + if (!target->isPlaceHold && !IsContain(target->types, scope->primitive->value.type)) { + return false; + } + // path is setted and not pointer to this node + if (target->pathSetted) { + MS_ASSERT(target->path != nullptr); + if (target->path->nodeIdx != nodeIdx) { + return false; + } + } + target->SetPath(-1, nodeIdx); + sinkIdes.push_back(nodeIdx); + pathSinkIdes.push_back(nodeIdx); + // target is marked head, no need to check left and right. head-target's left + // and right is always nullptr + if (target->isHead) { + return true; + } + auto preNodeIdxes = GetInputNodeIdx(*graph, nodeIdx); + if (preNodeIdxes.empty() && target->left == nullptr && target->right == nullptr) { + return true; + } + for (auto preNodeIdx : preNodeIdxes) { + MS_ASSERT(subGraph->nodes.size() > preNodeIdx); + // match left + if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) { + // match right + if (preNodeIdxes.size() == 1 && target->right == nullptr) { + return true; + } + for (auto preNodeIdxInner : preNodeIdxes) { + if (preNodeIdxInner == preNodeIdx) { + continue; + } + MS_ASSERT(subGraph->nodes.size() > preNodeIdxInner); + if (MatchTree(graph, preNodeIdxInner, target->right, sinkIdes, pathSinkIdes)) { + return true; // ignore follow match, pick the first match + } + } + } + } + sinkIdes.erase((sinkIdes.end() - 1)); + pathSinkIdes.erase((pathSinkIdes.end() - 1)); + target->UnSetPath(); + return false; +} + +STATUS FusionPass::Fuse(schema::MetaGraphT *graph) { + STATUS ret; + bool isChange = false; + for (auto iter = mapedMatchedPaths.begin(); iter != mapedMatchedPaths.end(); iter++) { + for (auto &matchedPath : iter->second) { + ret = DoFusion(graph, iter->first, matchedPath); + if (ret != RET_OK && ret != RET_NO_CHANGE) { + MS_LOG(ERROR) << "DoFusion Error " << ret; + return ret; + } else { + if (ret == RET_OK) { + isChange = true; + } + } + } + } + return isChange ? RET_OK : RET_NO_CHANGE; +} + +FusionPass::~FusionPass() { + for (auto pattern : patterns) { + if (pattern != nullptr) { + delete (pattern); + } + } +} + +void FusionPass::MergeNodeAttrFromPost(std::unique_ptr &dstOp, std::unique_ptr &postOp, + size_t dstOpOutIdx) { + // // merge quantParam + // if (dstOp->quantParam.empty()) { // not awareing quant + // return; + // } + // MS_ASSERT(postOp->outputIndex.size() == 1); + // if (dstOp->quantParam.size() != dstOp->inputIndex.size() + dstOp->outputIndex.size()) { + // int a = 1; + // } + // MS_ASSERT(dstOp->quantParam.size() == dstOp->inputIndex.size() + dstOp->outputIndex.size()); + // auto &dstQuantParamArray = dstOp->quantParam.at(dstOp->inputIndex.size() + dstOpOutIdx); + // auto &postQuantParamArray = postOp->quantParam.back(); + // if (!(postQuantParamArray != nullptr && postQuantParamArray->param.size() == 1 && + // postQuantParamArray->param.front() != nullptr && postQuantParamArray->param.front()->min != FLT_MAX)) { + // return; // postNode has no quantParam, no need merge + // } + // + // if ((dstQuantParamArray != nullptr && dstQuantParamArray->param.size() != 1) || + // (dstQuantParamArray->param.front() != nullptr && dstQuantParamArray->param.front()->min != FLT_MAX)) { + // return; // dstNode has quantParam, no need merge + // } + // + // dstQuantParamArray->param.front()->min = postQuantParamArray->param.front()->min; + // dstQuantParamArray->param.front()->max = postQuantParamArray->param.front()->max; + // dstQuantParamArray->param.front()->scale = postQuantParamArray->param.front()->scale; + // dstQuantParamArray->param.front()->zeroPoint = postQuantParamArray->param.front()->zeroPoint; + // MS_LOGD("merge quantParam from %s to %s", postOp->name.c_str(), dstOp->name.c_str()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.h new file mode 100644 index 00000000000..b7de4cc8307 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pass.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FUSION_PASS_H +#define MINDSPORE_PREDICT_FUSION_PASS_H + +#include +#include +#include +#include +#include +#include "tools/common/converter_op_utils.h" +#include "tools/converter/optimizer.h" +#include "tools/converter/optimizer/fusion/fusion_pattern.h" + +namespace mindspore { +namespace lite { +#define CONV_OP_NO_BIAS_WEIGHT_NUM 1 +#define CONV_OP_HAS_BIAS_WEIGHT_NUM 2 +#define CONV_OP_NO_BIAS_INPUT_NUM 2 +#define CONV_OP_HAS_BIAS_INPUT_NUM 3 + +#define CONV_OP_FILTER_INDEX_IN_WEIGHT 0 +#define CONV_OP_BIAS_INDEX_IN_WEIGHT 1 +#define CONV_OP_FILTER_INDEX_IN_INPUT 1 +#define CONV_OP_BIAS_INDEX_IN_INPUT 2 + +#define CONV_FILTER_SHAPE_SIZE 4 + +// PatternOp Ids +constexpr const char *kConvName = "CONVOLUTION"; +constexpr const char *DST_NAME = "DESTINATION"; +constexpr const char *ACTIVATION_NAME = "ACTIVATION"; +constexpr const char *BIASADD_NAME = "BIASADD"; + +class FusionPass : public GraphPass { + public: + FusionPass() = default; + + ~FusionPass() override; + + virtual STATUS DefinePattern() = 0; + + STATUS Run(schema::MetaGraphT *graph) override; + + virtual STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) = 0; + + protected: + STATUS MatchPatterns(schema::MetaGraphT *graph); + + STATUS MatchOnePattern(schema::MetaGraphT *graph, FusionPattern *pattern); + + bool MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std::shared_ptr &target, + std::vector &sinkIdes, std::vector &pathSinkIdes); + + static bool CheckMatch(schema::MetaGraphT *graph, const std::shared_ptr& patternOp); + + void MergeNodeAttrFromPost(std::unique_ptr &dstOp, std::unique_ptr &postOp, + size_t dstOpOutIdx = 0); + + STATUS Fuse(schema::MetaGraphT *graph); + + protected: + std::vector patterns; + std::map>> matchedPaths; + // {name of pattern, vector<{name of pattern node, path}>} + std::map>>> mapedMatchedPaths; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.cc b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.cc new file mode 100644 index 00000000000..0b23502ad24 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pattern.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +// using namespace std; + +FusionPattern::FusionPattern(std::string name) { this->name = std::move(name); } + +FusionPattern::~FusionPattern() = default; + +FusionPattern &FusionPattern::SetName(const std::string &name) { + this->name = name; + return *this; +} + +FusionPattern &FusionPattern::AddPatternOp(const std::string &id, + const std::initializer_list &types) { + return AddPatternOp(id, std::vector(types)); +} + +FusionPattern &FusionPattern::AddPatternOp(const std::string &id, const std::vector &types) { + if (id.empty()) { + // MS_LOG(ERROR) << "Id cannot be empty"); + hasError = true; + } + + if (GetPatternOp(id) != nullptr) { + // MS_LOG(ERROR) << "Id repeated. (id:%s)", id.c_str()); + hasError = true; + } + + std::shared_ptr op(new PatternOp()); + if (op == nullptr) { + // MS_LOG(ERROR) << "new an object failed"); + hasError = true; + } + + op->id = id; + op->types = types; + ops.push_back(op); + opMap[id] = op; + + return *this; +} + +FusionPattern &FusionPattern::RemovePatternOp(const std::string &id) { + for (uint32_t loop = 0; loop < ops.size(); loop++) { + std::shared_ptr op = ops.at(loop); + if (op->id == id) { + ops.erase(ops.begin() + loop); + opMap.erase(id); + break; + } + } + return *this; +} + +bool FusionPattern::Check() { + if (hasError) { + // MS_LOG(ERROR) << "Has Error in previous Func"); + return false; + } + + if (GetPatternOp(this->outputOpId) == nullptr) { + // MS_LOG(ERROR) << "Can not find the output of the pattern"); + return false; + } + + return true; +} + +void FusionPattern::Dump() const { + std::ostringstream oss; + oss << std::endl << "Pattern " << name << std::endl; + for (const auto op : ops) { + oss << " " << op->id << ": {"; + for (auto &type : op->types) { + oss << schema::EnumNamePrimitiveType(type) << ", "; + } + oss << "} {"; + if (op->left != nullptr) { + oss << "leftPreNode: " << op->left->id << ", "; + } + if (op->right != nullptr) { + oss << "rightPreNode: " << op->right->id << ", "; + } + oss << "}"; + + oss << std::endl; + } +} + +std::shared_ptr FusionPattern::GetPatternOp(const std::string &id) const { + auto it = opMap.find(id); + if (it != opMap.end()) return it->second; + + return nullptr; +} + +std::string FusionPattern::GetOutput() const { return this->outputOpId; } + +FusionPattern &FusionPattern::AddPatternOp(const std::shared_ptr &patternOp) { + ops.push_back(patternOp); + opMap[patternOp->id] = patternOp; + return *this; +} + +FusionPattern &FusionPattern::Finish() { + std::vector ids; + std::set nodeInputIds; + std::vector inputNodeIds; + for (auto patternOp : ops) { + if (IsContain(ids, patternOp->id)) { + // MS_LOG(ERROR) << "Duplicate id find: %s", patternOp->id.c_str()); + hasError = true; + return *this; + } + ids.emplace_back(patternOp->id); + if (patternOp->left != nullptr) { + nodeInputIds.insert(patternOp->left->id); + } + if (patternOp->right != nullptr) { + nodeInputIds.insert(patternOp->right->id); + } + if (patternOp->left == nullptr && patternOp->right == nullptr) { + inputNodeIds.emplace_back(patternOp->id); + } + } + for (auto iter = ids.begin(); iter != ids.end();) { + if (nodeInputIds.find(*iter) != nodeInputIds.end()) { + iter = ids.erase(iter); + } else { + iter++; + } + } + if (ids.size() > 1) { + // MS_LOG(ERROR) << "Multi-output node find, only support pattern with one output"); + hasError = true; + return *this; + } + if (ids.empty()) { + // MS_LOG(ERROR) << "No output node find, only support pattern with one output"); + hasError = true; + return *this; + } + this->outputOpId = ids.front(); + auto outputNode = GetPatternOp(this->outputOpId); + MS_ASSERT(outputNode != nullptr); + outputNode->isTail = true; + + for (auto inputNodeId : inputNodeIds) { + auto inputNode = GetPatternOp(inputNodeId); + MS_ASSERT(inputNode != nullptr); + inputNode->isHead = true; + } + return *this; +} + +std::string FusionPattern::GetName() { return this->name; } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.h b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.h new file mode 100644 index 00000000000..334185e7367 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/fusion_pattern.h @@ -0,0 +1,141 @@ +#include + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FUSION_PATTERN_H +#define MINDSPORE_PREDICT_FUSION_PATTERN_H + +#include +#include +#include +#include +// #include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +struct Path { + public: + Path(int32_t subGraphIdx, int32_t nodeIdx) : subGraphIdx(subGraphIdx), nodeIdx(nodeIdx) {} + int32_t subGraphIdx = -1; + int32_t nodeIdx = -1; +}; + +// Op description in pattern +struct PatternOp { + std::string id; // id of op in pattern + std::vector types; // type of matchable op + // TODO(...): only support node with no more than two preNode now + // avoid loop reference + std::shared_ptr left; // left input patternOp of this patternOp + std::shared_ptr right; // right input patternOp of this patternOp + std::shared_ptr path = std::make_shared(-1, -1); + bool pathSetted = false; + bool isHead = false; + bool isTail = false; + bool isPlaceHold = false; + + PatternOp() = default; + explicit PatternOp(const std::string &inId) : id(inId) {} + ~PatternOp() = default; + void SetPath(size_t subGraphIdx, size_t nodeIdx) { + MS_ASSERT(this->path != nullptr); + this->path->subGraphIdx = subGraphIdx; + this->path->nodeIdx = nodeIdx; + this->pathSetted = true; + } + void UnSetPath() { + MS_ASSERT(this->path != nullptr); + this->path->subGraphIdx = -1; + this->path->nodeIdx = -1; + this->pathSetted = false; + } + static std::shared_ptr Copy(const std::shared_ptr& src) { + if (src == nullptr) { + return nullptr; + } + auto dst = std::make_shared(); + dst->id = src->id; + dst->types = src->types; + if (src->path != nullptr) { + dst->path = std::make_shared(src->path->subGraphIdx, src->path->nodeIdx); + } + dst->pathSetted = src->pathSetted; + dst->isTail = src->isTail; + dst->isHead = src->isHead; + dst->isPlaceHold = src->isPlaceHold; + dst->left = PatternOp::Copy(src->left); + dst->right = PatternOp::Copy(src->right); + return dst; + } +}; + +class FusionPattern { + public: + explicit FusionPattern(std::string name = ""); + + ~FusionPattern(); + + std::string GetName(); + + FusionPattern &SetName(const std::string &name); + + FusionPattern &AddPatternOp(const std::string &id, const std::initializer_list &types = {}); + + FusionPattern &AddPatternOp(const std::string &id, const std::vector &types); + + FusionPattern &AddPatternOp(const std::shared_ptr& patternOp); + + FusionPattern &RemovePatternOp(const std::string &id); + + // set id of patternOp + // set isTail and isHead for patternOps + FusionPattern &Finish(); + + bool Check(); + // get the id of the output Op of th pattern + std::string GetOutput() const; + + void Dump() const; + + // return nullptr if not find + std::shared_ptr GetPatternOp(const std::string &id) const; + + private: + FusionPattern(const FusionPattern &) = default; + + FusionPattern &operator=(const FusionPattern &) = default; + + private: + std::string name; + + std::vector> ops; + + // same with ops, just for search + std::map> opMap; + + // output PatternOp id of pattern + std::string outputOpId; + + bool hasError = false; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FUSION_PATTERN_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.cc new file mode 100644 index 00000000000..ff461a2a630 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +// #include "utils/log_adapter.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" +#include "src/common/op_utils.h" + +namespace mindspore { +namespace lite { +#define MATMUL_BIASADD_MATCH_PATH_LEN 2 +#define BIASADD_OP_BIAS_INDEX 1 +#define BIASADD_OP_INPUT_NUM 2 + +STATUS MatMulBiasAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS MatMulBiasAddFusionPass::DefinePattern() { + auto matMulOp = std::make_shared(); + matMulOp->id = MATMUL_NAME; + matMulOp->types = {schema::PrimitiveType_MatMul}; + auto baOp = std::make_shared(); + baOp->id = BIASADD_NAME; + baOp->types = {schema::PrimitiveType_BiasAdd}; + baOp->left = matMulOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern("MatMulBiasAddFusion")); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failed"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(matMulOp); + fusionPattern->AddPatternOp(baOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + + return RET_OK; +} + +STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != MATMUL_BIASADD_MATCH_PATH_LEN) { + MS_LOG(ERROR) << "MatMul-BiasAdd-Fusion should have two NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto matMulPath = matchedPath[MATMUL_NAME]; + auto baPath = matchedPath[BIASADD_NAME]; + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + auto &baNode = graph->nodes.at(baPath->nodeIdx); + // can not check shape because there is now shape infer in converter + MS_ASSERT(matMulNode != nullptr); + MS_ASSERT(matMulNode->inputIndex.size() == 2); + // biasadd node the second tensor is not constant tensor, don't fusion + auto baNodeInputIndex = baNode->inputIndex; + if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << "%s node tensors number is invalid! "; // baNode->name.c_str()); + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX)); + const auto &baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex.at(BIASADD_OP_BIAS_INDEX)); + MS_ASSERT(baNodeBiasTensor != nullptr); + if (baNodeBiasTensor->refCount != schema::NodeType_ValueNode) { + // dont fusion, return + return RET_OK; + } + + // 1. add biasTensor for matMul + auto status = AddFullConnectionBiasTensor(matMulPath, baPath, graph); + if (RET_OK != status) { + MS_LOG(ERROR) << "AddFullConnectionBiasTensor failed, %d"; // status); + return status; + } + + // 2. change matmul to full connection op + matMulNode->name += "-fc"; + std::unique_ptr fcAttr(new FullConnectionT()); + if (fcAttr == nullptr) { + MS_LOG(ERROR) << "new FullConnectionT node failed"; + return RET_ERROR; + } + fcAttr->hasBias = true; + fcAttr->axis = 1; + MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr); + transA = matMulNode->primitive->value.AsMatMul()->transposeA; + transB = matMulNode->primitive->value.AsMatMul()->transposeB; + MS_ASSERT(matMulNode->primitive->value.value != nullptr); + delete (matMulNode->primitive->value.value); + matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection; + matMulNode->primitive->value.value = fcAttr.release(); + + // 3. delete BiasAdd node + MergeNodeAttrFromPost(matMulNode, baNode); + status = IsolateOneWayNode(graph, baPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: %zu, node: %zu, error: %d"; + // baPath->subGraphIdx, baPath->nodeIdx, status); + return status; + } + + // 4. addTranspose node + status = InsertTransposeNode(graph, matMulPath); + if (status != RET_OK) { + MS_LOG(ERROR) + << "InsertTransposeNode failed, subGraph: %zu, node: %zu, error: %d"; // matMulPath->subGraphIdx, + // matMulPath->nodeIdx, status); + return status; + } + return RET_OK; +} + +STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std::shared_ptr &matMulPath) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(matMulPath != nullptr); + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + MS_ASSERT(graph->allTensors.size() > matMulNode->inputIndex.at(0)); + MS_ASSERT(graph->allTensors.size() > matMulNode->inputIndex.at(2)); + const auto &tensorA = graph->allTensors.at(matMulNode->inputIndex.at(0)); + const auto &tensorB = graph->allTensors.at(matMulNode->inputIndex.at(1)); + + std::vector insertNodeIdxList; + if (transA) { + insertNodeIdxList.emplace_back(0); + } + if (!transB) { + insertNodeIdxList.emplace_back(1); + } + + auto matmulOpIter = graph->nodes.begin() + matMulPath->nodeIdx; + STATUS errorCode = RET_OK; + for (auto needInsertIdx : insertNodeIdxList) { + auto transNode = std::unique_ptr(new (std::nothrow) CNodeT); + if (transNode == nullptr) { + MS_LOG(ERROR) << "new TransNode failed"; + return RET_ERROR; + } + transNode->name = "transpose" + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Transpose; + std::unique_ptr transposeParam(new TransposeT()); + if (transposeParam == nullptr) { + MS_LOG(ERROR) << "new transposeParam failed"; + return RET_ERROR; + } + transposeParam->conjugate = false; + transposeParam->perm = {1, 0}; + transNode->primitive->value.value = transposeParam.release(); + matmulOpIter = + InsertNode(graph, matmulOpIter, kBefore, needInsertIdx, std::move(transNode), &errorCode, TransposeOpCopyer); + if (errorCode != RET_OK) { + MS_LOG(ERROR) << "InsertNode failed: %d"; // errorCode); + return errorCode; + } + } + return RET_OK; +} + +#define BIASADD_WEIGHT_SHAPE_SIZE 1 +#define BIASADD_BIAS_DIM_INDEX 0 + +STATUS MatMulBiasAddFusionPass::AddFullConnectionBiasTensor(const std::shared_ptr &matMulPath, + const std::shared_ptr &baPath, MetaGraphT *graph) { + MS_ASSERT(matMulPath != nullptr); + MS_ASSERT(baPath != nullptr); + MS_ASSERT(graph != nullptr); + + MS_ASSERT(graph->nodes.size() > matMulPath->nodeIdx); + auto &matMulNode = graph->nodes.at(matMulPath->nodeIdx); + MS_ASSERT(matMulNode != nullptr); + auto baNode = graph->nodes.at(baPath->nodeIdx).get(); + MS_ASSERT(baNode != nullptr); + + // check biasTensor + auto baWeightTensorIdxes = baNode->inputIndex; + if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) { + MS_LOG(ERROR) << "%s node tensors number is invalid! "; // baNode->name.c_str()); + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + auto &biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + MS_ASSERT(biasTensor != nullptr); + auto biasDims = biasTensor->dims; + // if biasTensor is a scaler + if (biasDims.empty() && biasTensor->data.data() == nullptr) { + MS_LOG(ERROR) << "BiasAdd node %s bias tensor is invalid"; // baNode->name.c_str()); + return RET_ERROR; + } + if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) { + MS_LOG(ERROR) + << "BiasAdd bias tensor should has one dimension, current number of dimension %zu. or bias tensor is a scaler"; + // biasDims.size()); + return RET_ERROR; + } + // add biasTensor to matmul + matMulNode->inputIndex.emplace_back(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX)); + baNode->inputIndex.erase(baNode->inputIndex.begin() + BIASADD_OP_BIAS_INDEX); + + return RET_OK; +} + +MatMulBiasAddFusionPass::~MatMulBiasAddFusionPass() = default; +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h new file mode 100644 index 00000000000..4edc1eca8d1 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/matmul_biasadd_fusion_pass.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H +#define MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H + +#include +#include +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +constexpr const char *MATMUL_NAME = "MATMUL"; + +class MatMulBiasAddFusionPass : public FusionPass { + public: + MatMulBiasAddFusionPass() = default; + + ~MatMulBiasAddFusionPass() override; + + STATUS DefinePattern() override; + + STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(MetaGraphT *graph) override; + + protected: + static STATUS AddFullConnectionBiasTensor(const std::shared_ptr& matMulPath, + const std::shared_ptr& dstPath, + MetaGraphT *subGraph); + STATUS InsertTransposeNode(MetaGraphT *subGraph, const std::shared_ptr& matMulPath); + + protected: + bool transA = false; + bool transB = false; + size_t id = 0; + + OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr &inOpDef) -> std::unique_ptr { + std::unique_ptr newOpDef(new (std::nothrow) CNodeT); + if (newOpDef == nullptr) { + MS_LOG(ERROR) << "new OpDefT failed"; + return nullptr; + } + newOpDef->name = inOpDef->name; + newOpDef->quantType = inOpDef->quantType; + newOpDef->primitive->value.type = schema::PrimitiveType_Transpose; + auto transposeParam = new (std::nothrow) TransposeT; + if (transposeParam == nullptr) { + MS_LOG(ERROR) << "new transposeParam failed"; + return nullptr; + } + auto inParam = inOpDef->primitive->value.AsTranspose(); + MS_ASSERT(inParam != nullptr); + transposeParam->conjugate = inParam->conjugate; + transposeParam->perm.resize(inParam->perm.size()); + std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(), + [](const int32_t ele) { return ele; }); + newOpDef->primitive->value.value = transposeParam; + return std::move(newOpDef); + }; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MATMUL_BIASADD_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.cc b/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.cc new file mode 100644 index 00000000000..8b760af808d --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/fusion/quant_cast_fusion_pass.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +#define kQuantCastMatchPathLen2 2 +#define kQuantCastMatchPathLen3 3 + +STATUS QuantCastFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } + +STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + if (matchedPath.size() != kQuantCastMatchPathLen2 && matchedPath.size() != kQuantCastMatchPathLen3) { + MS_LOG(ERROR) << "QuantDtypeCastFusion should have " << kQuantCastMatchPathLen2 << " or " << + kQuantCastMatchPathLen3 << " NodeIndex in matchedPair"; + return RET_PARAM_INVALID; + } + + auto srcPath = matchedPath[kQuantCastSrcOp]; + MS_ASSERT(srcPath != nullptr); + auto dstPath = matchedPath[kQuantCastDstOp]; + MS_ASSERT(dstPath != nullptr); + auto srcNode = graph->nodes.at(srcPath->nodeIdx).get(); + MS_ASSERT(srcNode != nullptr); + auto dstNode = graph->nodes.at(dstPath->nodeIdx).get(); + MS_ASSERT(dstNode != nullptr); + + // todo check + if (srcNode->inputIndex.empty() && srcNode->outputIndex.empty()) { + MS_LOG(DEBUG) << "srcNode " << srcNode->name.c_str() << " has been removed"; + return RET_NO_CHANGE; + } + if (dstNode->inputIndex.empty() && dstNode->outputIndex.empty()) { + MS_LOG(DEBUG) << "dstNode " << dstNode->name.c_str() << " has been removed"; + return RET_NO_CHANGE; + } + + auto srcAttr = srcNode->primitive->value.AsQuantDTypeCast(); + auto dstAttr = dstNode->primitive->value.AsQuantDTypeCast(); + MS_ASSERT(srcAttr != nullptr); + MS_ASSERT(dstAttr != nullptr); + if (srcAttr->dstT != dstAttr->srcT || srcAttr->srcT != dstAttr->dstT) { + MS_LOG(ERROR) << "srcNode and dstNode can not been fused"; + return RET_ERROR; + } + + auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name.c_str() << ", error: " << status; + return status; + } + + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status; + return status; + } + + return RET_OK; +} + +STATUS QuantCastFusionPass::DefinePattern() { + // quantCast + quantCast + { + auto srcOp = std::make_shared(); + srcOp->id = kQuantCastSrcOp; + srcOp->types = {schema::PrimitiveType_QuantDTypeCast}; + auto dstOp = std::make_shared(); + dstOp->id = kQuantCastDstOp; + dstOp->types = {schema::PrimitiveType_QuantDTypeCast}; + dstOp->left = srcOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(kQuantCastFusionPattern)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failde"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(srcOp); + fusionPattern->AddPatternOp(dstOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + // quantCast + formatTrans + quantCast + { + auto srcOp = std::make_shared(); + srcOp->id = kQuantCastSrcOp; + srcOp->types = {schema::PrimitiveType_QuantDTypeCast}; + auto formatOp = std::make_shared(); + formatOp->id = kFormatTransOp; + formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc}; + formatOp->left = srcOp; + auto dstOp = std::make_shared(); + dstOp->id = kQuantCastDstOp; + dstOp->types = {schema::PrimitiveType_QuantDTypeCast}; + dstOp->left = formatOp; + + std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(kQuantCastPassFusionPattern)); + if (fusionPattern == nullptr) { + MS_LOG(ERROR) << "new fusionPattern failde"; + return RET_ERROR; + } + fusionPattern->AddPatternOp(srcOp); + fusionPattern->AddPatternOp(formatOp); + fusionPattern->AddPatternOp(dstOp); + fusionPattern->Finish(); + + this->patterns.emplace_back(fusionPattern.release()); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.h b/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.h new file mode 100644 index 00000000000..1600fbc7027 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/fusion/quant_cast_fusion_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H +#define MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H + +#include +#include +#include +#include "tools/converter/optimizer/fusion/fusion_pass.h" + +namespace mindspore { +namespace lite { +constexpr const char *kQuantCastSrcOp = "QuantCastSrcOp"; +constexpr const char *kFormatTransOp = "FormatTransOp"; +constexpr const char *kQuantCastDstOp = "QuantCastDstOp"; + +constexpr const char *kQuantCastFusionPattern = "QuantCastFusionPattern"; +constexpr const char *kQuantCastPassFusionPattern = "QuantCastPassFusionPattern"; + +class QuantCastFusionPass : public FusionPass { + public: + QuantCastFusionPass() = default; + + ~QuantCastFusionPass() override = default; + + STATUS DefinePattern() override; + + STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, + std::unordered_map> &matchedPath) override; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_QUANT_CAST_FUSION_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/optimizer/graph/CMakeLists.txt new file mode 100755 index 00000000000..e5d2ceac19f --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(graph_pass_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc + ) diff --git a/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc new file mode 100644 index 00000000000..ceb17f0a0b9 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/optimizer/graph/format_trans_pass.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "utils/log_adapter.h" +#include "src/common/common.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +#define kMinInputNum 1 +#define kOutputNum 1 + +STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + auto status = DoModelInputFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; + return status; + } + status = DoNodeInoutFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; + return status; + } + return RET_OK; +} + +STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + // insert trans node in model input tensor + if (graph->nodes.empty()) { + return RET_OK; + } + auto graphInputIdxes = graph->inputIndex; + for (size_t i = 0; i < graphInputIdxes.size(); i++) { + auto inputIdx = graphInputIdxes.at(i); + MS_ASSERT(inputIdx < subGraph->allTensors.size()); + auto &tensor = graph->allTensors.at(inputIdx); + if (tensor->dims.size() != kNCHWDimNumber) { + continue; + } + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { + if (node->inputIndex.at(inputIndexIdx) == inputIdx) { + STATUS status = RET_OK; + iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; + return status; + } + // set first tensor format to nhwc + auto &transNode = *(iter - 1); + MS_ASSERT(transNode != nullptr); + MS_ASSERT(transNode->inputIndex.size() == 1); + MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); + auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); + graphInTensor->format = schema::Format_NHWC; + // assume parser not reformat shape + auto oldDims = graphInTensor->dims; + graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; + break; + } + } + } + } + return RET_OK; +} + +// inference needed inputFormat: +// conv deconv depth dedepth +// fp32 NCHW NCHW NCHW NCHW +// uint8 NCHW ? NCHW ? +STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // insert before and after the op cal by nchw/nc4hw4 + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + FormatTransNodeType beforeNodeType, afterNodeType; + if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use + // nhwc + // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only + // support nhwc + // continue; + // } + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + // continue; + // } + // } else { + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + // } + // } + // beforeNodeType = kNCHW2NHWC; + // afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc + // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc + // continue; + // } + // } else { + // continue; + // } + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_MS) { + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else { + MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; + return RET_ERROR; + } + auto &node = *iter; + auto nodeName = node->name; + if (node->inputIndex.size() < kMinInputNum) { + MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + if (node->outputIndex.size() != kOutputNum) { + MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; + return RET_ERROR; + } + STATUS status; + iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; + return RET_ERROR; + } + + iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, + size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) { + MS_ASSERT((*existNodeIter) != nullptr); + auto existNodeName = (*existNodeIter)->name; + std::string tileName; + if (place == kBefore) { + tileName = existNodeName + "_pre"; + } else { + tileName = existNodeName + "_post"; + } + auto transNode = std::make_unique(); + transNode->primitive = std::make_unique(); + + if (nodeType == kNCHW2NHWC) { + transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; + } else { + transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; + } + return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); +} + +// void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h b/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h new file mode 100644 index 00000000000..c2ce7eb49b2 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/format_trans_pass.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H +#define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H + +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW }; + +class FormatTransPass : public GraphPass { + public: + FormatTransPass() : id(0) {} + + ~FormatTransPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + // void SetQuantType(QuantType quantType); + + void SetFmk(converter::FmkType fmkType); + + private: + STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); + + STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); + + NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, + FormatTransNodeType nodeType, STATUS *errorCode); + + private: + size_t id; + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.cc b/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.cc new file mode 100644 index 00000000000..0c2d2c0d3fb --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "tools/converter/optimizer/graph/isolated_node_remove_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS IsolatedNodeRemovePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + bool ifChanged = false; + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end();) { + if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { + ifChanged = true; + iter = graph->nodes.erase(iter); + } else { + iter++; + } + } + return ifChanged ? RET_OK : RET_NO_CHANGE; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.h b/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.h new file mode 100644 index 00000000000..293ccd8920f --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/isolated_node_remove_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H +#define MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H + +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class IsolatedNodeRemovePass : public GraphPass { + public: + IsolatedNodeRemovePass() = default; + + ~IsolatedNodeRemovePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.cc b/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.cc new file mode 100644 index 00000000000..caac57c18a8 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "tools/converter/optimizer/graph/model_input_format_preprocess_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +STATUS ModelInputFormatPreProcessPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto inputIndex : graph->inputIndex) { + if (graph->allTensors[inputIndex]->dims.size() == 4) { + std::vector tmpDims(graph->allTensors[inputIndex]->dims); + auto status = + NodeUtils::ConvertDims(schema::Format_NCHW, tmpDims, schema::Format_NHWC, &graph->allTensors[inputIndex]->dims); + if (status == RET_OK) { + graph->allTensors[inputIndex]->format = schema::Format_NHWC; + } else { + MS_LOG(ERROR) << "ConvertDims from NHWC to NCHW error: " << status; + return RET_ERROR; + } + } else { + graph->allTensors[inputIndex]->format = schema::Format_NHWC; + } + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.h b/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.h new file mode 100644 index 00000000000..187c93079ec --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/model_input_format_preprocess_pass.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H +#define MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H + +#include +#include "tools/converter/optimizer.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class ModelInputFormatPreProcessPass : public GraphPass { + public: + ModelInputFormatPreProcessPass() = default; + + ~ModelInputFormatPreProcessPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_MODEL_FORMAT_PREPROCESS_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.cc b/mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.cc new file mode 100644 index 00000000000..c9cdfdaa647 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "tools/converter/optimizer/graph/topological_sort_pass.h" +#include "tools/common/converter_op_utils.h" +#include "utils/log_adapter.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + std::vector> newNodes; + std::vector sinkedTensorIdxes; + // put all const tensor index into sinkedTensorIdxes + for (size_t i = 0; i < graph->allTensors.size(); i++) { + if (graph->allTensors.at(i)->nodeType == schema::NodeType_ValueNode) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i); + } + } + auto &oldNodes = graph->nodes; + std::queue> opQueue; + // put all non depend node into queue + for (auto &node : graph->nodes) { + if (IsNodeNonDepend(node, sinkedTensorIdxes)) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end()); + opQueue.push(std::move(node)); + } + } + // bfs + while (!opQueue.empty()) { + auto &node = opQueue.front(); + auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get())); + for (auto postNodeIdx : postNodeIdxes) { + auto &postNode = oldNodes.at(postNodeIdx); + // check if postNode is non-depended + if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) { + sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end()); + opQueue.push(std::move(postNode)); + } + } + newNodes.emplace_back(std::move(node)); + opQueue.pop(); + } + if (newNodes.size() != oldNodes.size()) { + MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size() + << ", newNodesSize: " << newNodes.size(); + return RET_ERROR; + } + graph->nodes.swap(newNodes); + return RET_OK; +} + +bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr &node, + const std::vector &sinkedTensorIdxes) { + for (auto inputIdx : node->inputIndex) { + if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) { + return false; + } + } + return true; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/core/ir/lite/param_value_lite.h b/mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.h similarity index 51% rename from mindspore/core/ir/lite/param_value_lite.h rename to mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.h index 0384d992a94..994648ab57b 100644 --- a/mindspore/core/ir/lite/param_value_lite.h +++ b/mindspore/lite/tools/converter/optimizer/graph/topological_sort_pass.h @@ -14,30 +14,29 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_IR_LITE_PARAM_VALUE_LITE_H_ -#define MINDSPORE_CORE_IR_LITE_PARAM_VALUE_LITE_H_ +#ifndef MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H +#define MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H #include - -#include "ir/param_value.h" +#include +#include "mindspore/lite/tools/converter/optimizer.h" +#include "tools/common/graph_util.h" namespace mindspore { -class ParamValueLite : public ParamValue { +namespace lite { +class TopologicalSortPass : public GraphPass { public: - ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} - virtual ~ParamValueLite() = default; + TopologicalSortPass() = default; - size_t tensor_size() const { return tensor_size_; } - void set_tensor_size(size_t size) { tensor_size_ = size; } + ~TopologicalSortPass() override = default; - void *tensor_addr() const { return tensor_addr_; } - void set_tensor_addr(void *addr) { tensor_addr_ = addr; } + STATUS Run(schema::MetaGraphT *graph) override; private: - void *tensor_addr_; - size_t tensor_size_; + bool IsNodeNonDepend(const std::unique_ptr &node, const std::vector &sinkedTensorIdxes); }; - -using ParamValueLitePtr = std::shared_ptr; +} // namespace lite } // namespace mindspore -#endif // MINDSPORE_CORE_IR_LITE_PARAM_VALUE_LITE_H_ + +#endif // MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.cc b/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.cc new file mode 100644 index 00000000000..c1159c23b51 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.h" +#include "utils/log_adapter.h" +#include "tools/common/converter_op_utils.h" +#include "src/common/utils.h" +#include "tools/common/graph_util.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +STATUS UnusedNodeRemovePass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + bool ifChanged = false; + for (size_t i = 0; i < graph->nodes.size(); i++) { + auto &node = graph->nodes.at(i); + if (node->primitive->value.type == schema::PrimitiveType_TupleGetItem) { + ifChanged = true; + auto status = IsolateOneWayNode(graph, i); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << graph->name << ", node: " << node->name + << ", error: " << status; + return status; + } + } + } + return ifChanged ? RET_OK : RET_NO_CHANGE; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.h b/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.h new file mode 100644 index 00000000000..7716592a249 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/graph/unused_node_remove_pass.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H +#define MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H + +#include +#include "tools/converter/optimizer.h" + +namespace mindspore { +namespace lite { +class UnusedNodeRemovePass : public GraphPass { + public: + UnusedNodeRemovePass() = default; + + ~UnusedNodeRemovePass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_UNUSED_NODE_REMOVE_PASS_H + diff --git a/mindspore/lite/tools/converter/optimizer/node/CMakeLists.txt b/mindspore/lite/tools/converter/optimizer/node/CMakeLists.txt new file mode 100755 index 00000000000..6288071c81a --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/node/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(node_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_pass.cc + ) diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc new file mode 100644 index 00000000000..3400c884515 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc @@ -0,0 +1,394 @@ +/** + * Copyright 201+ Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/optimizer/node/weight_format_pass.h" +#include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +int WeightFormatPass::Run(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto status = ShapeFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; + return status; + } + if (this->quantType == QuantType_AwareTrainning) { + status = QuantDataFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; + return status; + } + } else { + status = NonQuantDataFormatTrans(graphNode); + if (status != 0) { + MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; + return status; + } + } + return 0; +} + +// void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } + +// pre set tensor format +// non quant, filterFormat: +// conv deconv depth dedepth +// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp +// tf HWCK HWKC HWCK HWKC +// onnx K(C/g)HW C(K/g)HW / / + +// awareing quant, filterFormat: +// conv deconv depth dedepth +// onnx KHWC ? CHWK ? +// tf HWCK ? HWCK ? +int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { + return 0; + } + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + auto &shape = weightTensor->dims; + MS_ASSERT(shape.size() == 4); + if (fmkType == converter::FmkType_CAFFE) { + switch (node->quantType) { + case QuantType_QUANT_NONE: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_KCHW; + } else { + MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType) + << ", node: " << node->name.c_str(); + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType) + << ", node: " << node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_MS) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + case QuantType_QUANT_NONE: { + // conv [filter_height, filter_width, in_channels, out_channels] + // depthwise [filter_height, filter_width, in_channels, channel_multiplier] + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KCHW; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_KCHW; + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name; + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_TF) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + case QuantType_QUANT_NONE: { + // conv [filter_height, filter_width, in_channels, out_channels] + // depthwise [filter_height, filter_width, in_channels, channel_multiplier] + if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_HWCK; + } else { + weightTensor->format = schema::Format_HWKC; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_TFLITE) { + switch (node->quantType) { + case QuantType_QUANT_NONE: + case QuantType_AwareTrainning: { + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KHWC; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CHWK; + } else { + MS_LOG(ERROR) << "unsupport format"; + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } + return 0; + } else if (fmkType == converter::FmkType_ONNX) { + switch (node->quantType) { + case QuantType_AwareTrainning: { + // sum up from current onnx quant models + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KHWC; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CHWK; + } else { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } break; + case QuantType_QUANT_NONE: { + // conv (K x C/group x kH x kW) group = 1 + // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) + // deconv (C x K/group x kH x kW) group = 1 + // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) + if (opType == schema::PrimitiveType_Conv2D) { + weightTensor->format = schema::Format_KCHW; + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { + weightTensor->format = schema::Format_CKHW; + } else if (opType == schema::PrimitiveType_DeConv2D) { + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str(); + return -1; + } + } break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: %d, node: " << node->quantType, node->name.c_str(); + return -1; + } + } + } else { + MS_LOG(ERROR) << "Invalid fmkType: %d, node: " << fmkType, node->name.c_str(); + return -1; + } + return 0; +} + +int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D) { + return 0; + } + + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + MS_ASSERT(weightTensor->dataType == -22); // DataType_DT_FLOAT + STATUS status; + if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK + if (weightTensor->format == schema::Format_KCHW) { // from caffe + if (weightTensor->dataType == -22) { // DataType_DT_UINT8) { + MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format + << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } else { + MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format + << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + } + } else if (weightTensor->format == schema::Format_KHWC) { // from onnx + if (weightTensor->dataType == -22) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + } else { + status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); + } + } else if (weightTensor->format == schema::Format_HWCK) { // from tf + return 0; + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_HWCK; + } else { + MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : " + << (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"), + node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK + if (weightTensor->format == schema::Format_CKHW) { // from caffe + if (weightTensor->dataType == -22) { // DataType_DT_UINT8) { + MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, + weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + } else { + MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, + weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + } + + } else if (weightTensor->format == schema::Format_HWCK) { // from tf + return 0; + } else if (weightTensor->format == schema::Format_CHWK) { // from onnx + if (weightTensor->dataType == -22) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + } else { + status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + } + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_HWCK; + } else { + MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " + << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), + node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK + node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_CKHW; + } + return 0; +} + +// inference needed filterFormat: +// conv deconv depth dedepth +// fp32 KCHW CKHW CKHW CKHW +int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { + MS_ASSERT(graphNode != nullptr); + auto &subGraph = graphNode->subGraph; + auto &node = graphNode->opDef; + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto opType = node->primitive->value.type; + if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && + opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { + return 0; + } + + MS_ASSERT(node->inputIndex.size() >= 2); + auto weightIndex = node->inputIndex.at(1); + MS_ASSERT(subGraph->allTensors.size() > weightIndex); + auto &weightTensor = subGraph->allTensors[weightIndex]; + if (weightTensor->dataType != TypeId::kNumberTypeFloat32) { + MS_LOG(ERROR) << "weight tensor data should be float"; + // return -1; + } + STATUS status = RET_OK; + if (opType == schema::PrimitiveType_Conv2D) { // weight should be KCHW + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_KHWC) { + status = RET_OK; + } else if (weightTensor->format == schema::Format_CHWK) { + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_KHWC; + } else { + MS_LOG(WARNING) << "TransFilter " << ((weightTensor->format == schema::Format_HWCK) ? "HWCK" : "NHWC") + << "ToKCHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW + if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms + status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); + } else if (weightTensor->format == schema::Format_KCHW) { + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); + } else if (weightTensor->format == schema::Format_CHWK) { + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KCHW + if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx + return 0; + } else if (weightTensor->format == schema::Format_HWKC) { // from tf + status = TransFilterFormat(weightTensor.get(), kHWKC2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_KCHW; + } else { + MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) { // weight should be CKHW + if (weightTensor->format == schema::Format_CKHW) { // from caffe + return 0; + } else if (weightTensor->format == schema::Format_HWKC) { // from tf or onnx + status = TransFilterFormat(weightTensor.get(), kHWKC2CKHW); + } else { + MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; + return -1; + } + if (status == 0) { + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; + weightTensor->format = schema::Format_CKHW; + } else { + MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str(); + // todo(00445839): consider varible weight condition + } + } + return 0; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h new file mode 100644 index 00000000000..3a6d1217cf4 --- /dev/null +++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H +#define MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H + +#include "tools/converter/optimizer.h" +#include "tools/converter/converter_flags.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +class WeightFormatPass : public NodePass { + public: + WeightFormatPass() = default; + + ~WeightFormatPass() override = default; + + // void SetQuantType(QuantType quantType); + + void SetFmkType(converter::FmkType fmkType); + + int Run(GraphNode *graphNode) override; + + private: + // correct weightTensor->Format + int ShapeFormatTrans(GraphNode *graphNode); + + // transform weightTensor data and format + // if quant : conv transform dataFormat to NHWC, weight format to HWCK + // if quant : depth transform dataFormat to NCHW, weight format to CKHW + int QuantDataFormatTrans(GraphNode *graphNode); + + // if no quant : transform dataFormat to NCHW, weight format to KCHW/CKHW + int NonQuantDataFormatTrans(GraphNode *graphNode); + + private: + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H + diff --git a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt new file mode 100644 index 00000000000..97406e5bc71 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt @@ -0,0 +1,52 @@ +add_library(caffe_parser_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/caffe.pb.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_argmax_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_argmax_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_batchnorm_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_batchnorm_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_concat_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_concat_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_conv_base_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_conv_base_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_converter.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_convolution_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_convolution_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_crop_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_crop_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_deconvolution_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_deconvolution_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_eltwise_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_eltwise_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_innerproduct_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_innerproduct_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_relu_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_relu_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_reshape_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_reshape_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_scale_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_scale_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_sigmoid_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_sigmoid_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_softmax_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_softmax_parser.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_inspector.h + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_interp_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_interp_parser.h) diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe.proto b/mindspore/lite/tools/converter/parser/caffe/caffe.proto new file mode 100755 index 00000000000..75ae1aa357b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe.proto @@ -0,0 +1,1675 @@ +syntax = "proto2"; + +package caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. + // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional ProposalParameter proposal_param = 900; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ROIPoolingParameter roi_pooling_param = 147; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. + enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. + optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) + // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + +// Message that stores parameters used by ROIPoolingLayer +message ROIPoolingParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; +} + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by ProposalLayer +message ProposalParameter { + optional float feat_stride = 1; + optional float base_size = 2; + optional float min_size = 3; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6; + optional int32 post_nms_topn = 7; + optional float nms_thresh = 8; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermutetLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; //分类的种类 + optional uint32 coords = 2 [default = 4]; //box的坐标数 + optional uint32 boxes = 3 [default = 1]; //每个grid预测的boxes数 + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + optional int32 axis = 1 [default = 1]; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional int32 scale = 1[default = 1]; +} diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc new file mode 100644 index 00000000000..a1035eec523 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + op->name = proto.name(); + std::unique_ptr attr(new schema::ArgMaxT()); + const caffe::ArgMaxParameter argmaxParam = proto.argmax_param(); + + int32_t axisType = 0; + int32_t axis = 0; + if (!argmaxParam.has_axis()) { + axisType = 2; + } else { + axisType = 1; + axis = (int64_t)argmaxParam.axis(); + if (axis == -1) { + // MS_LOGE("axis with -1 may lead to calculation errors when input less than 4 dims."); + return RET_ERROR; + } + } + + attr->axis = axis; + attr->axisType = axisType; + attr->outMaxValue = argmaxParam.out_max_val(); + attr->topK = argmaxParam.top_k(); + attr->keepDims = true; + + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeArgMaxParser("ArgMax", new CaffeArgMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h new file mode 100644 index 00000000000..b539c496875 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_argmax_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeArgMaxParser : public CaffeNodeParser { + public: + CaffeArgMaxParser() : CaffeNodeParser("argmax") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ARGMAX_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc new file mode 100644 index 00000000000..2e8e78fa373 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h" +#include "tools/common/tensor_util.h" + +#define CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT 0.00001 +#define CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT 0.000000001 + +static const int CAFFE_BATCHNORMAL_BOTTOM_SIZE = 1; +static const int CAFFE_BATCHNORMAL_TOP_SIZE = 1; + +namespace mindspore { +namespace lite { +using STATUS = int; +STATUS CaffeBatchNormParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + op->name = proto.name(); + // caffe batch norm attr + std::unique_ptr attr(new FusedBatchNormT()); + const caffe::BatchNormParameter batchNormParam = proto.batch_norm_param(); + + // check bottom size + if (proto.bottom_size() != CAFFE_BATCHNORMAL_BOTTOM_SIZE) { + // MS_LOGE("Layer %s bottom numbers is error, it must be %d, but is %d", proto.name().c_str(), + // CAFFE_BATCHNORMAL_BOTTOM_SIZE, proto.bottom_size()); + return RET_ERROR; + } + + // check top size + if (proto.top_size() != CAFFE_BATCHNORMAL_TOP_SIZE) { + // MS_LOGE("Layer %s top numbers is error, it must be %d, but is %d", \ + proto.name().c_str(), CAFFE_BATCHNORMAL_TOP_SIZE, + // proto.top_size()); + return RET_ERROR; + } + + if (batchNormParam.has_eps()) { + if (fabs(CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT - batchNormParam.eps()) < CAFFE_BATCH_NORM_ESP_DEFAULT_DIFF_FLOAT) { + attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; + } else { + auto tmpAuto = batchNormParam.eps(); + attr->epsilon = tmpAuto; + } + } else { + attr->epsilon = CAFFE_BATCH_NORM_ESP_DEFAULT_FLOAT; + } + + const float blob2Data = + (weight.blobs(2).double_data_size() > 0) ? weight.blobs(2).double_data(0) : weight.blobs(2).data(0); + const float scaleFactor = blob2Data == 0 ? 0 : 1 / blob2Data; + + // parse weight gamma + auto gamma = ConvertWeight(weight.blobs(0)); + if (gamma == nullptr) { + // MS_LOGE("Convert blobs(0) for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + auto estimatedMean = reinterpret_cast(gamma->data.data()); + auto estimatedMeanShapeSize = GetShapeSize(*gamma); + for (size_t i = 0; i < estimatedMeanShapeSize; i++) { + estimatedMean[i] = estimatedMean[i] * scaleFactor; + } + estimatedMean = nullptr; + weightVec->push_back(gamma); + + // parse weight beta + auto beta = ConvertWeight(weight.blobs(1)); + if (beta == nullptr) { + // MS_LOGE("Convert blobs(1) for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + auto estimatedVariance = reinterpret_cast(beta->data.data()); + size_t estimatedVarianceShapeSize = GetShapeSize(*beta); + for (size_t i = 0; i < estimatedVarianceShapeSize; i++) { + estimatedVariance[i] = estimatedVariance[i] * scaleFactor; + } + estimatedVariance = nullptr; + weightVec->push_back(beta); + + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + op->primitive->value.value = attr.release(); + + return RET_OK; +} + +CaffeNodeRegistrar g_caffeBatchNormParser("BatchNorm", new CaffeBatchNormParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h new file mode 100644 index 00000000000..aca0e41dbd1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_batchnorm_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeBatchNormParser : public CaffeNodeParser { + public: + CaffeBatchNormParser() : CaffeNodeParser("batchnorm") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_BATCHNORM_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc new file mode 100644 index 00000000000..cd4ce95c7aa --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h" + +const int32_t CONCAT_DEFAULT_AXIS = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + op->name = proto.name(); + std::unique_ptr attr(new schema::ConcatT()); + const caffe::ConcatParameter concatParam = proto.concat_param(); + + if (concatParam.has_axis() && concatParam.has_concat_dim()) { + // MS_LOGE("Concat param in caffe have concat_dim and axis simultaneously,return fail"); + return RET_ERROR; + } + + if (concatParam.has_concat_dim()) { + // MS_LOGD("Concat dim , set axis:%d", concatParam.concat_dim()); + int32_t concat_dim_value = (int32_t)concatParam.concat_dim(); + + if (concat_dim_value < 0) { + // MS_LOGE("concat_dim value in model is smaller than 0:%d", concat_dim_value); + return RET_ERROR; + } + attr->axis = concat_dim_value; + } else if (concatParam.has_axis()) { + // MS_LOGD("axis , set axis:%d", concatParam.axis()); + int32_t tmpInt = (int32_t)concatParam.axis(); + attr->axis = tmpInt; + } else { + // MS_LOGD("default , set axis:%d", CONCAT_DEFAULT_AXIS); + attr->axis = CONCAT_DEFAULT_AXIS; + } + + attr->n = proto.bottom_size(); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeConcatParser("Concat", new CaffeConcatParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h new file mode 100644 index 00000000000..10ae7013d24 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_concat_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeConcatParser : public CaffeNodeParser { + public: + CaffeConcatParser() : CaffeNodeParser("concat") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONCAT_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc new file mode 100644 index 00000000000..b3bc5adee2d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +const uint32_t PAD_DEFAULT_VALUE = 0; +const uint32_t STRIDE_DEFAULT_VALUE = 1; +const uint32_t DILATION_DEFAULT_VALUE = 1; +const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; +const uint32_t DEFAULT_CONV_GROUP = 1; +static const int CAFFE_CONV_BIAS_DIM_NUM = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeConvBaseParser::ParsePads(const caffe::ConvolutionParameter &convParam, std::vector *pad) { + /** + * padUp = padH; + * padDown = padH; + * padLeft = padW; + * padRight = padW; + */ + if (convParam.has_pad_h() || convParam.has_pad_w()) { + if (convParam.pad_size() != 0) { + // MS_LOGE("Either pad or pad_h/w should be specified; not both"); + return RET_ERROR; + } + + if (!convParam.has_pad_h()) { + (*pad)[0] = PAD_DEFAULT_VALUE; + (*pad)[1] = PAD_DEFAULT_VALUE; + (*pad)[2] = convParam.pad_w(); + (*pad)[3] = convParam.pad_w(); + } else if (!convParam.has_pad_w()) { + (*pad)[0] = convParam.pad_h(); + (*pad)[1] = convParam.pad_h(); + (*pad)[2] = PAD_DEFAULT_VALUE; + (*pad)[3] = PAD_DEFAULT_VALUE; + } else { + (*pad)[0] = convParam.pad_h(); + (*pad)[1] = convParam.pad_h(); + (*pad)[2] = convParam.pad_w(); + (*pad)[3] = convParam.pad_w(); + } + } else { + // default 2D + const int num_pad_dims = convParam.pad_size(); + int num_spatial_dims = std::max(num_pad_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_pad_dims == 0) ? PAD_DEFAULT_VALUE : convParam.pad((num_pad_dims == 1) ? 0 : i)); + } + // default 2D + (*pad)[0] = vec[0]; + (*pad)[1] = vec[0]; + (*pad)[2] = vec[1]; + (*pad)[3] = vec[1]; + } + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseStrides(const caffe::ConvolutionParameter &convParam, std::vector *stride) { + if (convParam.has_stride_h() || convParam.has_stride_w()) { + if (convParam.stride_size() != 0) { + // MS_LOGE("Either stride or stride_h/w should be specified; not both"); + return RET_ERROR; + } + if (!convParam.has_stride_h() || !convParam.has_stride_w()) { + // MS_LOGE("stride_h/w must appear at the same time!"); + return RET_ERROR; + } + (*stride)[0] = convParam.stride_h(); + (*stride)[1] = convParam.stride_w(); + } else { + const int num_stride_dims = convParam.stride_size(); + int num_spatial_dims = std::max(num_stride_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_stride_dims == 0) ? STRIDE_DEFAULT_VALUE : convParam.stride((num_stride_dims == 1) ? 0 : i)); + } + // default 2D + (*stride)[0] = vec[0]; + (*stride)[1] = vec[1]; + } + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseDilations(const caffe::ConvolutionParameter &convParam, + std::vector *dilation) { + const int num_dilation_dims = convParam.dilation_size(); + int num_spatial_dims = std::max(num_dilation_dims, SPATIAL_DIM_DEFAULT_SIZE); + + std::vector vec; + for (int i = 0; i < num_spatial_dims; ++i) { + vec.push_back((num_dilation_dims == 0) ? DILATION_DEFAULT_VALUE + : convParam.dilation((num_dilation_dims == 1) ? 0 : i)); + } + // default 2D + (*dilation)[0] = vec[0]; + (*dilation)[1] = vec[1]; + + return RET_OK; +} + +STATUS CaffeConvBaseParser::ParseKernels(const caffe::ConvolutionParameter &convParam, std::vector *kernel) { + if (convParam.has_kernel_h() || convParam.has_kernel_w()) { + if (convParam.kernel_size_size() != 0) { + // MS_LOGE("Either kernel_size or kernel_h/w should be specified; not both.") + return RET_ERROR; + } + if (convParam.has_kernel_h() && convParam.has_kernel_w()) { + (*kernel)[0] = convParam.kernel_h(); + (*kernel)[1] = convParam.kernel_w(); + } else { + // MS_LOGE("kernel_h/w must appear at the same time!"); + return RET_ERROR; + } + } else if (convParam.kernel_size_size() != 0) { + int kernel_size = convParam.kernel_size_size(); + int num_spatial_dims = std::max(kernel_size, SPATIAL_DIM_DEFAULT_SIZE); + std::vector vec; + for (int i = 0; i < num_spatial_dims; i++) { + vec.push_back(convParam.kernel_size((kernel_size == 1) ? 0 : i)); + } + // default 2D + (*kernel)[0] = vec[0]; + (*kernel)[1] = vec[1]; + } else { + return RET_ERROR; + } + return RET_OK; +} + +int CaffeConvBaseParser::ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType) { + // group default 1 + int group = 0; + if (convParam.has_group()) { + group = convParam.group(); + } else { + layerType == "ConvolutionDepthwise" ? (group = convParam.num_output()) : (group = DEFAULT_CONV_GROUP); + } + return group; +} + +int CaffeConvBaseParser::ParseChannelIn(const caffe::LayerParameter &proto, const int &group) { + int res = 0; + auto &weightBlob = proto.blobs(0); + if (weightBlob.has_shape()) { + res = weightBlob.shape().dim(1) * group; + } else { + // get shape information from Blob parameters(caffe proto v1) + if (proto.type() == "Deconvolution") { + res = weightBlob.num() * group; + } else { + res = weightBlob.channels() * group; + } + } + return res; +} + +int CaffeConvBaseParser::ParseChannelOut(const caffe::ConvolutionParameter &convParam) { + if (!convParam.has_num_output()) { + // MS_LOGE("Parse num_output for failed."); + } + return convParam.num_output(); +} + +STATUS CaffeConvBaseParser::ParseWeight(const caffe::LayerParameter &weight, + std::vector *weightVec) { + // Layer must have Filter + if (weight.blobs_size() == 0) { + // MS_LOGE("No filter data in layer %s", weight.name().c_str()); + return RET_ERROR; + } + + auto filter = ConvertWeight(weight.blobs(0)); + if (filter == nullptr) { + // MS_LOGE("Convert weight for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(filter); + + // parse bias + const caffe::ConvolutionParameter convParam = weight.convolution_param(); + if (convParam.bias_term() && weight.blobs_size() > 1) { + auto bias = ConvertWeight(weight.blobs(1)); + if (bias == nullptr) { + // MS_LOGE("Convert bias for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + + std::vector shape = bias->dims; + if (shape.size() != CAFFE_CONV_BIAS_DIM_NUM) { + // MS_LOGE("Bias dim-num of layer %s is not supported"); + return RET_ERROR; + } + weightVec->push_back(bias); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h new file mode 100644 index 00000000000..d1e2886879b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeConvBaseParser { + public: + CaffeConvBaseParser() {} + + virtual ~CaffeConvBaseParser() {} + + STATUS ParsePads(const caffe::ConvolutionParameter &conv_param, std::vector *pad); + + STATUS ParseStrides(const caffe::ConvolutionParameter &conv_param, std::vector *stride); + + STATUS ParseDilations(const caffe::ConvolutionParameter &conv_param, std::vector *dilation); + + STATUS ParseKernels(const caffe::ConvolutionParameter &conv_param, std::vector *kernel); + + int ParseGroup(const caffe::ConvolutionParameter &convParam, const std::string &layerType); + + int ParseChannelOut(const caffe::ConvolutionParameter &convParam); + + int ParseChannelIn(const caffe::LayerParameter &proto, const int &group); + + STATUS ParseWeight(const caffe::LayerParameter &weight, std::vector *weightVec); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONV_BASE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc new file mode 100644 index 00000000000..16056fa39d0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" + +namespace mindspore { +namespace lite { +CaffeConverter::CaffeConverter() { + modelParser = new CaffeModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h new file mode 100644 index 00000000000..889c5afefd2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ + +#include +#include +#include "mindspore/lite/tools/converter/converter.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" +#include "mindspore/lite/tools/converter/graphdef_transform.h" + +namespace mindspore::lite { +class CaffeConverter : public Converter { + public: + CaffeConverter(); + + ~CaffeConverter() override = default; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc new file mode 100644 index 00000000000..5c622f13d0d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h" + +namespace mindspore { +namespace lite { +void CaffeConvolutionParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { + if (attr == nullptr || attr->group == 1 || attr->group != attr->channelOut) { + return; + } + std::unique_ptr depthwiseConv2DParam(new schema::DepthwiseConv2DT()); + if (depthwiseConv2DParam == nullptr) { + // MS_LOGW("new DepthwiseConv2DT failed"); + return; + } + depthwiseConv2DParam->format = attr->format; + depthwiseConv2DParam->channelIn = attr->channelIn; + depthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + depthwiseConv2DParam->kernelW = attr->kernelW; + depthwiseConv2DParam->kernelH = attr->kernelH; + depthwiseConv2DParam->strideW = attr->strideW; + depthwiseConv2DParam->strideH = attr->strideH; + depthwiseConv2DParam->padMode = attr->padMode; + depthwiseConv2DParam->padUp = attr->padUp; + depthwiseConv2DParam->padDown = attr->padDown; + depthwiseConv2DParam->padLeft = attr->padLeft; + depthwiseConv2DParam->padRight = attr->padRight; + depthwiseConv2DParam->dilateW = attr->dilateW; + depthwiseConv2DParam->dilateH = attr->dilateH; + depthwiseConv2DParam->hasBias = attr->hasBias; + depthwiseConv2DParam->activationType = attr->activationType; + delete attr; + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = depthwiseConv2DParam.release(); +} + +STATUS CaffeConvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + op->name = proto.name(); + schema::Conv2DT *attr = new schema::Conv2DT(); + + attr->format = schema::Format_NCHW; + const caffe::ConvolutionParameter convParam = proto.convolution_param(); + + CaffeConvBaseParser convParser; + // parse pad + std::vector pad(4, 0); + auto status = convParser.ParsePads(convParam, &pad); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + } + attr->padUp = pad[0]; + attr->padDown = pad[1]; + attr->padLeft = pad[2]; + attr->padRight = pad[3]; + + // parse stride + std::vector stride(2, 0); + status = convParser.ParseStrides(convParam, &stride); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + } + attr->strideH = stride[0]; + attr->strideW = stride[1]; + + // parse dilation + std::vector dilation(2, 0); + status = convParser.ParseDilations(convParam, &dilation); + if (status != RET_OK) { + // MS_LOGE("ParseDilations for %s failed", proto.name().c_str()); + } + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + // parse kernel + std::vector kernel(2, 0); + status = convParser.ParseKernels(convParam, &kernel); + if (status != RET_OK) { + // MS_LOGE("ParseKernels for %s failed", proto.name().c_str()); + } + attr->kernelH = kernel[0]; + attr->kernelW = kernel[1]; + + attr->hasBias = convParam.bias_term(); + attr->group = convParser.ParseGroup(convParam, proto.type()); + attr->channelOut = convParser.ParseChannelOut(convParam); + attr->channelIn = convParser.ParseChannelIn(weight, attr->group); + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr; + + ParseGroupConvolution(op, attr); + status = convParser.ParseWeight(weight, weightVec); + if (status != RET_OK) { + // MS_LOGE("ParseWeight for %s failed", proto.name().c_str()); + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeConvolutionParser("Convolution", new CaffeConvolutionParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h new file mode 100644 index 00000000000..297c6242ba1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_convolution_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +namespace mindspore { +namespace lite { +class CaffeConvolutionParser : public CaffeNodeParser { + public: + CaffeConvolutionParser() : CaffeNodeParser("convolution") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + private: + void ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CONVOLUTION_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc new file mode 100644 index 00000000000..106fb3ad71e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h" + +const int32_t CROP_AXIS = 2; + +namespace mindspore { +namespace lite { +STATUS CaffeCropParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::CropT()); + if (!proto.has_crop_param()) { + attr->axis = CROP_AXIS; + std::vector offsets(2, 0); + attr->offsets = offsets; + } else { + const caffe::CropParameter cropParam = proto.crop_param(); + + if (cropParam.has_axis()) { + if (cropParam.axis() == -1) { + // MS_LOGW("axis with -1 may lead to calculation errors when input less than 4 dims."); + } + attr->axis = cropParam.axis(); + } else { + attr->axis = CROP_AXIS; + } + + if (cropParam.offset_size() != 0) { + std::vector offsets; + for (int i = 0; i < cropParam.offset_size(); i++) { + offsets.push_back(cropParam.offset(i)); + } + attr->offsets = offsets; + } + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Crop; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeCropParser("Crop", new CaffeCropParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h new file mode 100644 index 00000000000..7de30b5cc0f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_crop_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeCropParser : public CaffeNodeParser { + public: + CaffeCropParser() : CaffeNodeParser("crop") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_CROP_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc new file mode 100644 index 00000000000..be9682f2fd9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h" + +namespace mindspore { +namespace lite { +void CaffeDeconvolutionParser::ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { + if (attr == nullptr || attr->group == 1 || attr->group != attr->channelIn) { + return; + } + + std::unique_ptr deDepthwiseConv2DParam(new schema::DeDepthwiseConv2DT()); + if (deDepthwiseConv2DParam == nullptr) { + // MS_LOGW("new DeDepthwiseConv2DT failed"); + return; + } + deDepthwiseConv2DParam->format = attr->format; + deDepthwiseConv2DParam->channelIn = attr->channelOut; + deDepthwiseConv2DParam->channelMultiplier = attr->channelIn / attr->channelOut; + deDepthwiseConv2DParam->kernelW = attr->kernelW; + deDepthwiseConv2DParam->kernelH = attr->kernelH; + deDepthwiseConv2DParam->strideW = attr->strideW; + deDepthwiseConv2DParam->strideH = attr->strideH; + deDepthwiseConv2DParam->padMode = attr->padMode; + deDepthwiseConv2DParam->padUp = attr->padUp; + deDepthwiseConv2DParam->padDown = attr->padDown; + deDepthwiseConv2DParam->padLeft = attr->padLeft; + deDepthwiseConv2DParam->padRight = attr->padRight; + deDepthwiseConv2DParam->dilateW = attr->dilateW; + deDepthwiseConv2DParam->dilateH = attr->dilateH; + deDepthwiseConv2DParam->hasBias = attr->hasBias; + deDepthwiseConv2DParam->activationType = attr->activationType; + delete attr; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; + op->primitive->value.value = deDepthwiseConv2DParam.release(); +} +STATUS CaffeDeconvolutionParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + op->name = proto.name(); + schema::DeConv2DT *attr = new schema::DeConv2DT(); + attr->format = schema::Format_NCHW; + const caffe::ConvolutionParameter convParam = proto.convolution_param(); + + CaffeConvBaseParser convParser; + // parse pad + std::vector pad(4, 0); + auto status = convParser.ParsePads(convParam, &pad); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + } + attr->padUp = pad[0]; + attr->padDown = pad[1]; + attr->padLeft = pad[2]; + attr->padRight = pad[3]; + + // parse stride + std::vector stride(2, 0); + status = convParser.ParseStrides(convParam, &stride); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + } + attr->strideH = stride[0]; + attr->strideW = stride[1]; + + // parse dilation + std::vector dilation(2, 0); + status = convParser.ParseDilations(convParam, &dilation); + if (status != RET_OK) { + // MS_LOGE("ParseDilations for %s failed", proto.name().c_str()); + } + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + // parse kernel + std::vector kernel(2, 0); + status = convParser.ParseKernels(convParam, &kernel); + if (status != RET_OK) { + // MS_LOGE("ParseKernels for %s failed", proto.name().c_str()); + } + attr->kernelH = kernel[0]; + attr->kernelW = kernel[1]; + + attr->hasBias = convParam.bias_term(); + attr->group = convParser.ParseGroup(convParam, proto.type()); + attr->channelOut = convParser.ParseChannelOut(convParam); + attr->channelIn = convParser.ParseChannelIn(weight, attr->group); + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeConv2D; + op->primitive->value.value = attr; + ParseGroupDeconvolution(op, attr); + status = convParser.ParseWeight(weight, weightVec); + if (status != RET_OK) { + // MS_LOGE("ParseWeight for %s failed", proto.name().c_str()); + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeDeconvolutionParser("Deconvolution", new CaffeDeconvolutionParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h new file mode 100644 index 00000000000..834dd0a9c6f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_deconvolution_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_conv_base_parser.h" + +namespace mindspore { +namespace lite { +class CaffeDeconvolutionParser : public CaffeNodeParser { + public: + CaffeDeconvolutionParser() : CaffeNodeParser("deconvolution") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + private: + void ParseGroupDeconvolution(schema::CNodeT *op, schema::DeConv2DT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_DECONVOLUTION_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc new file mode 100644 index 00000000000..c750a2c5d86 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h" + +const int ELTWISE_MIN_INPUT_SIZE = 2; + +namespace mindspore { +namespace lite { +STATUS CaffeEltwiseParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + std::unique_ptr attr(new schema::EltwiseT()); + if (proto.bottom_size() < ELTWISE_MIN_INPUT_SIZE) { + // MS_LOGE("Eltwise Op '%s' need at least 2 inputs,but input size is %d", proto.name().c_str(), + // proto.bottom_size()); + return RET_ERROR; + } + + const caffe::EltwiseParameter eltwiseParam = proto.eltwise_param(); + + if (eltwiseParam.coeff_size() != 0 && eltwiseParam.coeff_size() != proto.bottom_size()) { + // MS_LOGE("Coeff size(%d) check fail, Eltwise Layer takes one coefficient per bottom blob.", + // eltwiseParam.coeff_size()); + return RET_PARAM_INVALID; + } + + if (eltwiseParam.operation() == caffe::EltwiseParameter::PROD && eltwiseParam.coeff_size() != 0) { + // MS_LOGE("Eltwise layer only takes coefficients for summation."); + return RET_ERROR; + } + + if (proto.has_eltwise_param() && eltwiseParam.has_operation()) { + switch (eltwiseParam.operation()) { + case caffe::EltwiseParameter::PROD: + attr->mode = schema::EltwiseMode_PROD; + break; + case caffe::EltwiseParameter::SUM: + attr->mode = schema::EltwiseMode_SUM; + break; + case caffe::EltwiseParameter::MAX: + attr->mode = schema::EltwiseMode_MAXIMUM; + break; + default: + // MS_LOGE("Eltwise parse params fail, unsupported opration %d.", eltwiseParam.operation()); + return RET_PARAM_INVALID; + } + } else { + attr->mode = schema::EltwiseMode_SUM; + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Eltwise; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeEltwiseParser("Eltwise", new CaffeEltwiseParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h new file mode 100644 index 00000000000..efae240ecd0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_eltwise_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeEltwiseParser : public CaffeNodeParser { + public: + CaffeEltwiseParser() : CaffeNodeParser("eltwise") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_ELTWISE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc new file mode 100644 index 00000000000..81c0a71c0d5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + const caffe::InnerProductParameter innerProductParam = proto.inner_product_param(); + std::unique_ptr attr(new schema::FullConnectionT()); + + if (!innerProductParam.has_num_output()) { + // MS_LOGE("InnerProduct Parse num_output for %s failed.", proto.name().c_str()); + return RET_ERROR; + } + + if (innerProductParam.axis() == 1) { + attr->axis = 1; + } else { + // MS_LOGE("InnerProduct Parse axis only support default 1, but actually %d.", innerProductParam.axis()); + return RET_ERROR; + } + + if (innerProductParam.bias_term()) { + attr->hasBias = true; + } + + // parse weight + if (weight.blobs_size() == 0) { + // MS_LOGE("InnerProduct No filter data in layer %s", weight.name().c_str()); + return RET_ERROR; + } + + // parse filter + auto filter = ConvertWeight(weight.blobs(0)); + if (filter == nullptr) { + // MS_LOGE("InnerProduct parse weight for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(filter); + + // parse bias + if (innerProductParam.bias_term() && weight.blobs_size() > 1) { + auto bias = ConvertWeight(weight.blobs(1)); + if (bias == nullptr) { + // MS_LOGE("InnerProduct parse bias for layer %s failed", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(bias); + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeInnerProductParser("InnerProduct", new CaffeInnerProductParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h new file mode 100644 index 00000000000..548c4535f78 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_innerproduct_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeInnerProductParser : public CaffeNodeParser { + public: + CaffeInnerProductParser() : CaffeNodeParser("innerproduct") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INNERPRODUCT_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc new file mode 100644 index 00000000000..18c4337b8fd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInspector::InspectModel(const caffe::NetParameter &proto) { + net = proto; + + if (proto.layer_size() == 0) { + // MS_LOGE("net layer num is zero, prototxt file may be invalid."); + return RET_ERROR; + } + + ParseInput(); + + SetTopsAndBottoms(); + + FindInputAndOutput(); +} + +STATUS CaffeInspector::ParseInput() { + if (net.input_size() > 0) { + // MS_LOGI("This net exist input."); + for (int i = 0; i < net.input_size(); i++) { + graphInput.insert(net.input(i)); + } + } + return RET_OK; +} + +STATUS CaffeInspector::FindInputAndOutput() { + for (auto iter : layerBottoms) { + if (layerTops.find(iter) == layerTops.end()) { + graphInput.insert(iter); + } + } + for (auto iter : layerTops) { + if (layerBottoms.find(iter) == layerBottoms.end()) { + graphOutput.insert(iter); + } + } +} + +STATUS CaffeInspector::SetTopsAndBottoms() { + for (int32_t i = 0; i < net.layer_size(); i++) { + caffe::LayerParameter &layer = const_cast(net.layer(i)); + if (layer.top_size() == 1 && layer.bottom_size() == 1 && layer.top(0) == layer.bottom(0)) { + continue; + } + if (layer.top_size() == 1 && layer.bottom_size() == 0) { + graphInput.insert(layer.top(0)); + } + for (int j = 0; j < layer.top_size(); j++) { + layerTops.insert(layer.top(j)); + } + for (int j = 0; j < layer.bottom_size(); j++) { + layerBottoms.insert(layer.bottom(j)); + } + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h new file mode 100644 index 00000000000..94bda8ddec5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class CaffeInspector { + public: + CaffeInspector() = default; + + STATUS InspectModel(const caffe::NetParameter &proto); + STATUS ParseInput(); + STATUS FindInputAndOutput(); + STATUS SetTopsAndBottoms(); + + std::set GetGraphInput() { return graphInput; } + std::set GetGraphOutput() { return graphOutput; } + + private: + caffe::NetParameter net; + + std::set layerTops; + std::set layerBottoms; + + std::set graphInput; + std::set graphOutput; +}; + +using CaffeInspectorPtr = std::shared_ptr; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INSPECTOR_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc new file mode 100644 index 00000000000..2ade9ec54db --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeInterpParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + std::unique_ptr attr(new schema::ResizeT()); + const caffe::InterpParameter interpParam = proto.interp_param(); + + if (interpParam.has_height()) { + int64_t height = interpParam.height(); + if (height < 0) { + // MS_LOGE("Interp height must be > 0"); + return RET_ERROR; + } + attr->newHeight = height; + } + + if (interpParam.has_width()) { + int64_t width = interpParam.width(); + if (width < 0) { + // MS_LOGE("Interp width must be > 0"); + return RET_ERROR; + } + attr->newWidth = width; + } + + attr->alignCorners = true; + attr->method = schema::ResizeMethod_BILINEAR; + + op->name = proto.name(); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h new file mode 100644 index 00000000000..675cc9ff883 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_interp_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeInterpParser : public CaffeNodeParser { + public: + CaffeInterpParser() : CaffeNodeParser("Interp") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_INTERP_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc new file mode 100755 index 00000000000..8992d2da1f6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -0,0 +1,307 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" +#include "tools/common/graph_util.h" + +namespace mindspore { +namespace lite { +CaffeModelParser::CaffeModelParser() {} + +CaffeModelParser::~CaffeModelParser() {} + +const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; + +schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + std::unique_ptr graph(new schema::MetaGraphT()); + + if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; + return nullptr; + } + + if (weightFile.empty()) { + MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary"; + return nullptr; + } + + if (ValidateFileStr(weightFile, ".caffemodel") != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel"; + return nullptr; + } + + std::unique_ptr subGraphDef(new schema::MetaGraphT()); + TensorCache tensorCache; + + caffe::NetParameter proto; + if (ReadProtoFromText((const char *)modelFile.c_str(), &proto) != RET_OK) { + MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile; + return nullptr; + } + subGraphDef->name = proto.name(); + + caffe::NetParameter weight; + if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) { + MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile; + return nullptr; + } + + auto status = GetModelInput(proto, &tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetModelInput failed " << status; + return nullptr; + } + + status = ParseLayer(proto, weight, &tensorCache, subGraphDef.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "ParseLayer failed " << status; + return nullptr; + } + + // set inputTensor index and outputTensor index for the whole graph + status = SetGraphTensorIndex(proto, &tensorCache, subGraphDef.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!"; + return nullptr; + } + subGraphDef->name = GetModelName(modelFile); + // set all tensors to graph + SetAllTensors(tensorCache, subGraphDef.get()); + graph = move(subGraphDef); + + ConvertCaffeBatchNorm(graph.get()); + + return graph.release(); +// return Fb2Anf(graph.release()); +} + +STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, + schema::CNodeT *op, + TensorCache *tensorCache) { + for (int i = 0; i < layer.bottom_size(); i++) { + int index = tensorCache->FindTensor(layer.bottom(i)); + if (index >= 0) { + op->inputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find input layer for %s.", layer.name().c_str()); + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, + schema::CNodeT *op, + TensorCache *tensorCache) { + for (int i = 0; i < layer.top_size(); i++) { + std::unique_ptr msTensor(new schema::TensorT()); + op->outputIndex.emplace_back(tensorCache->AddTensor(layer.top(i), msTensor.release(), OP_OUTPUT)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetWeightTensor(const std::vector &weightVec, schema::CNodeT *op, + TensorCache *tensorCache) { + for (auto iter : weightVec) { + op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) { + std::vector tensors = tensorCache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + subGraphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef) { + CaffeInspector caffeInspector; + caffeInspector.InspectModel(proto); + for (auto iter : caffeInspector.GetGraphInput()) { + int index = tensorCache->FindTensor(iter); + if (index >= 0) { + subGraphDef->inputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find input tensor layer for graph."); + return RET_ERROR; + } + } + + for (auto iter : caffeInspector.GetGraphOutput()) { + int index = tensorCache->FindTensor(iter); + if (index >= 0) { + subGraphDef->outputIndex.emplace_back(index); + } else { + // MS_LOGE("Can't find output tensor layer for graph."); + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, + TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) { + for (int i = 0; i < proto.layer_size(); i++) { + auto layer = proto.layer(i); + + caffe::LayerParameter layerP; + for (int j = 0; j < weight.layer_size(); j++) { + auto tempLayer = weight.layer(j); + if (tempLayer.name() == layer.name()) { + layerP = tempLayer; + break; + } + } + // todo y00520784 : layer.input_param().shape(0) + if (layer.type() == "Input") { + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { + msTensor->dims.push_back(layer.input_param().shape(0).dim(j)); + } + msTensor->nodeType = schema::NodeType_ValueNode; + msTensor->refCount = 1; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(layer.top(0), msTensor.release(), GRAPH_INPUT); + } else { + if (skipedLayerType.find(layer.type()) != skipedLayerType.end()) { + MS_LOG(INFO) << "Skip layer " << layer.name(); + continue; + } + + std::unique_ptr op(new schema::CNodeT()); + op->name = layer.name(); + + // set op input index + auto status = SetOpInputIdx(layer, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; + return status; + } + + auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); + if (nodeParser == nullptr) { + MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); + return RET_ERROR; + } + + std::vector weightVec; + status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); + if (status != RET_OK) { + MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; + return status; + } + // set op weight tensor to tensorcache + SetWeightTensor(weightVec, op.get(), tensorCache); + + // set op output index + status = SetOpOutputIdx(layer, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; + return status; + } + + // op->fmkType = FmkType_CAFFE; + subGraphDef->nodes.emplace_back(move(op)); + } + } + return RET_OK; +} + +STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { + for (int i = 0; i < proto.input_size(); i++) { + if (proto.input_dim_size() <= 0) { + continue; + } + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < proto.input_dim_size(); j++) { + msTensor->dims.push_back(proto.input_dim(j)); + } + msTensor->refCount = schema::NodeType_ValueNode; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); + } + + for (int i = 0; i < proto.input_shape_size(); i++) { + auto shape = proto.input_shape(i); + std::unique_ptr msTensor(new schema::TensorT()); + for (int j = 0; j < shape.dim_size(); j++) { + msTensor->dims.push_back(shape.dim(j)); + } + msTensor->refCount = schema::NodeType_ValueNode; + msTensor->dataType = kNumberTypeFloat32; + tensorCache->AddTensor(proto.input(i), msTensor.release(), GRAPH_INPUT); + } + return RET_OK; +} + +void CaffeModelParser::ConvertCaffeBatchNorm(schema::MetaGraphT *meta_graph) { + MS_ASSERT(meta_graph != nullptr); + auto &nodes = meta_graph->nodes; + for (auto &node : nodes) { + if (node->primitive->value.type != schema::PrimitiveType_FusedBatchNorm) { + continue; + } + MS_ASSERT(node->inputIndex.size() == 2); + MS_ASSERT(node->inputIndex.back() < meta_graph->allTensors.size()); + auto &meanTensor = meta_graph->allTensors.at(node->inputIndex.back()); + MS_ASSERT(nullptr != meanTensor); + auto shape = meanTensor->dims; + auto shapeSize = GetShapeSize(shape); + + auto scaleTensor = std::make_unique(); + scaleTensor->dims = shape; + scaleTensor->nodeType = NodeType_ValueNode; + scaleTensor->refCount = 1; + scaleTensor->format = schema::Format_NUM_OF_FORMAT; + scaleTensor->dataType = TypeId::kNumberTypeFloat32; + scaleTensor->data.resize(shapeSize * sizeof(float)); + auto scaleData = reinterpret_cast(scaleTensor->data.data()); + for (size_t i = 0 ; i < shapeSize; i++) { + scaleData[i] = 1; + } + + auto biasTensor = std::make_unique(); + biasTensor->dims = shape; + biasTensor->nodeType = NodeType_ValueNode; + biasTensor->refCount = 1; + biasTensor->format = schema::Format_NUM_OF_FORMAT; + biasTensor->dataType = TypeId::kNumberTypeInt32; + biasTensor->data.resize(shapeSize * sizeof(int32_t)); + auto biasData = reinterpret_cast(biasTensor->data.data()); + for (size_t i = 0 ; i < shapeSize; i++) { + biasData[i] = 0; + } + + node->inputIndex.insert(node->inputIndex.begin() + 1, meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(biasTensor)); + + node->inputIndex.insert(node->inputIndex.begin() + 1, meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(scaleTensor)); + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h new file mode 100644 index 00000000000..52297d3018f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ + +#include +#include +#include +#include +#include "mindspore/lite/tools/converter/model_parser.h" +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class CaffeModelParser : public ModelParser { + public: + CaffeModelParser(); + + virtual ~CaffeModelParser(); + + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + + private: + void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT); + + STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetWeightTensor(const std::vector &weightVec, schema::CNodeT *op, TensorCache *tensorCache); + + STATUS SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); + + STATUS SetGraphTensorIndex(const caffe::NetParameter &proto, + TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, TensorCache *tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache); + + static const std::set skipedLayerType; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_MODEL_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc new file mode 100644 index 00000000000..fbee603b147 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "securec/include/securec.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { +namespace lite { +schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) { + std::unique_ptr weight(new schema::TensorT()); + weight->format = schema::Format_NCHW; + std::vector shapeVec; + ConvertShape(proto, &shapeVec); + weight->dims = shapeVec; + weight->dataType = kNumberTypeFloat32; + weight->nodeType = schema::NodeType_ValueNode; + + // cal Weight num + int count = 1; + for (size_t i = 0; i < shapeVec.size(); ++i) { + int dim = shapeVec[i]; + if (dim <= 0) { + // MS_LOGE("Convert weight fail, Blob size invalid"); + return nullptr; + } + if (dim >= INT_MAX / count) { + // MS_LOGE("Convert weight fail, Blob size exceeds INT_MAX, dim:%d, count:%d", dim, count); + return nullptr; + } + count *= dim; + } + + // get weight + std::unique_ptr buf(new (std::nothrow) float[count]()); + if (buf == nullptr) { + return nullptr; + } + if (proto.double_data_size() > 0) { + // datatype double + if (count != proto.double_data_size()) { + // MS_LOGE("Convert weight fail, Blob size does not match shape size, shape size:%d, blob size:%d", count, + // proto.double_data_size()); + return nullptr; + } + + for (int i = 0; i < count; ++i) { + buf[i] = proto.double_data(i); + } + weight->data.resize(count * sizeof(float)); + ::memcpy_s(weight->data.data(), count * sizeof(float), + reinterpret_cast(buf.get()), + count * sizeof(float)); + } else { + // datatype float + if (count != proto.data_size()) { + // MS_LOGE("Convert weight fail, Blob size does not match shape size, shape size:%d, blob.data_size:%d", count, + // proto.data_size()); + return nullptr; + } + weight->data.resize(count * sizeof(float)); + const float *data_ptr = proto.data().data(); + ::memcpy_s(weight->data.data(), count * sizeof(float), (uint8_t *)data_ptr, count * sizeof(float)); + } + weight->refCount = 1; + + return weight.release(); +} + +STATUS ConvertShape(const caffe::BlobProto &proto, std::vector *shape) { + shape->clear(); + + if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { + // num, channels, height, width + shape->push_back(proto.num()); + shape->push_back(proto.channels()); + shape->push_back(proto.height()); + shape->push_back(proto.width()); + } else { + for (int i = 0; i < proto.shape().dim_size(); ++i) { + shape->push_back(proto.shape().dim(i)); + } + } +} +} // namespace lite +} // namespace mindspore +// + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h new file mode 100644 index 00000000000..e7b4f3d82b8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ + +#include +#include +#include "google/protobuf/message.h" +#include "mindspore/lite/schema/inner/model_generated.h" +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { + +class CaffeNodeParser { + public: + explicit CaffeNodeParser(const std::string &nodeName) : name(nodeName) {} + + virtual ~CaffeNodeParser() {} + + virtual int Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) = 0; + + protected: + const std::string &name; +}; + +schema::TensorT *ConvertWeight(const caffe::BlobProto &proto); + +STATUS ConvertShape(const caffe::BlobProto &proto, std::vector *shape); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc new file mode 100644 index 00000000000..33085e104b2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +CaffeNodeParserRegistry::CaffeNodeParserRegistry() {} + +CaffeNodeParserRegistry::~CaffeNodeParserRegistry() {} + +CaffeNodeParserRegistry *CaffeNodeParserRegistry::GetInstance() { + static CaffeNodeParserRegistry instance; + return &instance; +} + +CaffeNodeParser *CaffeNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h new file mode 100644 index 00000000000..75e45bcdfa2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ + +#include +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "tools/converter/parser/caffe/caffe.pb.h" + +namespace mindspore::lite { +class CaffeNodeParserRegistry { + public: + CaffeNodeParserRegistry(); + + virtual ~CaffeNodeParserRegistry(); + + static CaffeNodeParserRegistry *GetInstance(); + + CaffeNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class CaffeNodeRegistrar { + public: + CaffeNodeRegistrar(const std::string &name, CaffeNodeParser *parser) { + CaffeNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_NODE_PARSER_REGISTRY_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc new file mode 100644 index 00000000000..d543578a790 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/coded_stream.h" +#include "securec/include/securec.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +static const int PROTO_READ_BYTES_LIMIT = INT_MAX; // Max size of 2 GB minus 1 byte. +static const int WARNING_THRESHOLD = 536870912 * 2; + +bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, + google::protobuf::Message *proto) { + if (proto == nullptr) { + // MS_LOGE("incorrect parameter. nullptr == proto"); + return false; + } + coded_stream->SetTotalBytesLimit(PROTO_READ_BYTES_LIMIT, WARNING_THRESHOLD); + return proto->ParseFromCodedStream(coded_stream); +} + +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) { + if (file == nullptr || message == nullptr) { + return RET_ERROR; + } + + std::string realPath = RealPath(file); + if (realPath.empty()) { + // MS_LOGE("Proto file path is '%s' not valid", file); + return RET_ERROR; + } + + std::ifstream fs(realPath.c_str(), std::ifstream::in); + + if (!fs.is_open()) { + // MS_LOGE("Open proto file '%s' failed.", file); + return RET_ERROR; + } + + google::protobuf::io::IstreamInputStream input(&fs); + bool status = google::protobuf::TextFormat::Parse(&input, message); + if (status != true) { + // MS_LOGE("call [google::protobuf::TextFormat::Parse] func status fail, please check your text file."); + return RET_ERROR; + } + + fs.close(); + + return RET_OK; +} + +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) { + if (file == nullptr || message == nullptr) { + return RET_ERROR; + } + + std::string realPath = RealPath(file); + if (realPath.empty()) { + // MS_LOGE("Weight file path is '%s' not valid", file); + return RET_ERROR; + } + + std::ifstream fs(realPath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) { + // MS_LOGE("Open weight file '%s' failed.", file); + return RET_ERROR; + } + + google::protobuf::io::IstreamInputStream istream(&fs); + google::protobuf::io::CodedInputStream coded_stream(&istream); + + bool success = ReadProtoFromCodedInputStream(&coded_stream, message); + fs.close(); + + if (!success) { + // MS_LOGE("Parse %s failed.", file); + return RET_ERROR; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h new file mode 100644 index 00000000000..3ee51440df2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ + +#include +#include +#include "google/protobuf/message.h" + +#include "tools/converter/parser/caffe/caffe.pb.h" +#include "include/errorcode.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, + google::protobuf::Message *proto); + +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message); + +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc new file mode 100644 index 00000000000..c80dd6b21dc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h" +#include "utils/log_adapter.h" + +const uint32_t INNERPRODUCT_WINDOW_DEFAULT_VALUE = 0; +const uint32_t INNERPRODUCT_PAD_DEFAULT_VALUE = 0; + +namespace mindspore { +namespace lite { +STATUS CaffePoolingParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::PoolingT()); + attr->format = schema::Format_NCHW; + + const caffe::PoolingParameter poolingParam = proto.pooling_param(); + + auto status = ParsePads(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParsePads for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParseStrides(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParseStrides for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParseWindows(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParseWindows for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + status = ParsePoolingMode(poolingParam, attr.get()); + if (status != RET_OK) { + // MS_LOGE("ParsePoolingMode for %s failed", proto.name().c_str()); + return RET_ERROR; + } + + if (poolingParam.has_round_mode()) { + if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_FLOOR) { + attr->roundMode = schema::RoundMode_FLOOR; + } else if (poolingParam.round_mode() == caffe::PoolingParameter_RoundMode_CEIL) { + attr->roundMode = schema::RoundMode_CEIL; + } else { + MS_ASSERT(false); + } + } + + attr->padMode = schema::PadMode_CAFFE; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + return RET_OK; +} + +STATUS CaffePoolingParser::ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_pad_h() && poolingParam.has_pad_w()) { + if (poolingParam.has_pad()) { + // MS_LOGE("Either pad or pad_h/w should be specified; not both"); + return RET_ERROR; + } + attr->padLeft = poolingParam.pad_w(); + attr->padRight = poolingParam.pad_w(); + attr->padUp = poolingParam.pad_h(); + attr->padDown = poolingParam.pad_h(); + } else { + attr->padLeft = poolingParam.pad(); + attr->padRight = poolingParam.pad(); + attr->padUp = poolingParam.pad(); + attr->padDown = poolingParam.pad(); + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_stride_h() && poolingParam.has_stride_w()) { + if (poolingParam.has_stride()) { + // MS_LOGE("Either stride or stride_h/w should be specified; not both"); + return RET_ERROR; + } + attr->strideH = poolingParam.stride_h(); + attr->strideW = poolingParam.stride_w(); + } else { + attr->strideH = poolingParam.stride(); + attr->strideW = poolingParam.stride(); + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.has_global_pooling() && poolingParam.global_pooling()) { + if (poolingParam.has_kernel_size() || poolingParam.has_kernel_h() || poolingParam.has_kernel_w()) { + // MS_LOGE("With Global_pooling: true Filter size cannot specified"); + return RET_ERROR; + } + attr->windowH = INNERPRODUCT_WINDOW_DEFAULT_VALUE; + attr->windowW = INNERPRODUCT_WINDOW_DEFAULT_VALUE; + attr->global = true; + } else { + if (poolingParam.has_kernel_size() == (poolingParam.has_kernel_h() || poolingParam.has_kernel_w())) { + // MS_LOGE("Filter size is kernel_size OR kernel_h and kernel_w; not both"); + return RET_ERROR; + } + if (!poolingParam.has_kernel_size() && !(poolingParam.has_kernel_h() && poolingParam.has_kernel_w())) { + // MS_LOGE("For non-square filters both kernel_h and kernel_w are required."); + return RET_ERROR; + } + + if (poolingParam.has_kernel_h() && poolingParam.has_kernel_w()) { + attr->windowH = poolingParam.kernel_h(); + attr->windowW = poolingParam.kernel_w(); + } else { + attr->windowH = poolingParam.kernel_size(); + attr->windowW = poolingParam.kernel_size(); + } + } + return RET_OK; +} + +STATUS CaffePoolingParser::ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr) { + if (poolingParam.pool() == caffe::PoolingParameter::MAX) { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } else if (poolingParam.pool() == caffe::PoolingParameter::AVE) { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } else { + // MS_LOGE("Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."); + return RET_ERROR; + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffePoolingParser("Pooling", new CaffePoolingParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h new file mode 100644 index 00000000000..97d042f9ee5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_pooling_parser.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePoolingParser : public CaffeNodeParser { + public: + CaffePoolingParser() : CaffeNodeParser("pooling") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + STATUS ParsePads(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParseStrides(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParseWindows(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); + + STATUS ParsePoolingMode(const caffe::PoolingParameter &poolingParam, schema::PoolingT *attr); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POOLING_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc new file mode 100644 index 00000000000..0336bd1c4fb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h" + +static const float CAFFE_POWER_DEFAULT_POWER = 1.0; +static const float CAFFE_POWER_DEFAULT_SCALE = 1.0; +static const float CAFFE_POWER_DEFAULT_SHIFT = 0.0; + +namespace mindspore { +namespace lite { +STATUS CaffePowerParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::PowerT()); + const caffe::PowerParameter powerParam = proto.power_param(); + if (proto.has_power_param()) { + attr->power = powerParam.has_power() ? powerParam.power() : CAFFE_POWER_DEFAULT_POWER; + attr->scale = powerParam.has_scale() ? powerParam.scale() : CAFFE_POWER_DEFAULT_SCALE; + attr->shift = powerParam.has_shift() ? powerParam.shift() : CAFFE_POWER_DEFAULT_SHIFT; + } else { + attr->power = CAFFE_POWER_DEFAULT_POWER; + attr->scale = CAFFE_POWER_DEFAULT_SCALE; + attr->shift = CAFFE_POWER_DEFAULT_SHIFT; + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffePowerParser("Power", new CaffePowerParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h new file mode 100644 index 00000000000..c68b9dd9af2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_power_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePowerParser : public CaffeNodeParser { + public: + CaffePowerParser() : CaffeNodeParser("power") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_POWER_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc new file mode 100644 index 00000000000..1458d9415d7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0f + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffePReluParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::CaffePReLUT()); + const caffe::PReLUParameter pReluParam = proto.prelu_param(); + + if (pReluParam.has_channel_shared()) { + attr->channelShared = pReluParam.channel_shared(); + } else { + attr->channelShared = false; + } + + if (weight.blobs_size() == 0) { + // MS_LOGE("PRelu No blobs data in layer %s", proto.name().c_str()); + return RET_ERROR; + } + + auto slope = ConvertWeight(weight.blobs(0)); + if (slope == nullptr) { + // MS_LOGE("CaffePRelu convert slope for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(slope); + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_CaffePReLU; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffePReluParser("PReLU", new CaffePReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h new file mode 100644 index 00000000000..cfb53972fbf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_prelu_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffePReluParser : public CaffeNodeParser { + public: + CaffePReluParser() : CaffeNodeParser("pRelu") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_PRELU_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc new file mode 100644 index 00000000000..49ea560a5d8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeReluParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_RELU; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Activation; + // relu: negative_slope = 0, no parameter; + // leakyrelu: negative_slope != 0; + if (proto.has_relu_param() && proto.relu_param().has_negative_slope()) { + float negative_slope = proto.relu_param().negative_slope(); + + if (0 != negative_slope) { + std::unique_ptr attrLeakyReLu(new schema::LeakyReLUT()); + attrLeakyReLu->negativeSlope = negative_slope; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_LeakyReLU; + op->primitive->value.value = attrLeakyReLu.release(); + } + } + return RET_OK; +} + +CaffeNodeRegistrar g_caffeReluParser("ReLU", new CaffeReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h new file mode 100644 index 00000000000..618a53d694d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_relu_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeReluParser : public CaffeNodeParser { + public: + CaffeReluParser() : CaffeNodeParser("relu") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RELU_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc new file mode 100644 index 00000000000..ee0e461e984 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeReshapeParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ReshapeT()); + attr->format = schema::Format_NCHW; + + const caffe::ReshapeParameter reshapeParam = proto.reshape_param(); + + if (!reshapeParam.has_shape()) { + // MS_LOGE("Reshape has no shape info, ret fail"); + return RET_ERROR; + } + + const caffe::BlobShape &blob_shape = reshapeParam.shape(); + for (int i = 0; i < blob_shape.dim_size(); i++) { + attr->shape.push_back(blob_shape.dim(i)); + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeReshapeParser("Reshape", new CaffeReshapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h new file mode 100644 index 00000000000..142751e4573 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_reshape_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeReshapeParser : public CaffeNodeParser { + public: + CaffeReshapeParser() : CaffeNodeParser("reshape") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc new file mode 100644 index 00000000000..d2199f2abd8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h" + +const int32_t NCHW_DIM_C = 1; +const int32_t DIM_DEFAULT_SIZE = 4; + +namespace mindspore { +namespace lite { +STATUS CaffeScaleParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ScaleT()); + attr->format = schema::Format_NCHW; + + if (weight.blobs_size() + weight.bottom_size() < 2) { + // MS_LOGE("Scale bottom size:%d, blobs size:%d invalid in layer %s", weight.bottom_size(), weight.blobs_size(), + // weight.name().c_str()); + return RET_ERROR; + } + + const caffe::ScaleParameter scaleParam = weight.scale_param(); + int32_t axis = scaleParam.axis(); // NCHW_DIM_C; + uint32_t axis_index = NCHW_DIM_C; + + if (GetAxisIndex(axis, &axis_index)) { + // MS_LOGE("scale get axis failed for layer %s.", weight.name().c_str()); + } + + // parse scale + // todo expect only weight as scale not bias + if (weight.blobs().size() == 1) { + auto scale = ConvertWeight(weight.blobs(0)); + if (scale == nullptr) { + // MS_LOGE("Scale Convert blobs(0) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(scale); + } else if (weight.blobs().size() >= 2) { + auto scale = ConvertWeight(weight.blobs(0)); + if (scale == nullptr) { + // MS_LOGE("Scale Convert blobs(0) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(scale); + + // parse bias + bool scaleBias = scaleParam.bias_term(); + if (scaleBias) { + auto bias = ConvertWeight(weight.blobs_size() > 1 ? weight.blobs(1) : weight.blobs(0)); + if (bias == nullptr) { + // MS_LOGE("Scale Convert blobs(1) for layer %s failed.", weight.name().c_str()); + return RET_ERROR; + } + weightVec->push_back(bias); + } + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Scale; + return RET_OK; +} + +STATUS CaffeScaleParser::GetAxisIndex(const int32_t &axis, uint32_t *axis_index) { + if (axis < -DIM_DEFAULT_SIZE || axis >= DIM_DEFAULT_SIZE) { + // MS_LOGE("Scale axis value(%d) is not correct, ", axis); + return RET_PARAM_INVALID; + } + + if (axis == -1) { + // MS_LOGW("axis with -1 may lead to calculation errors when input less than 4 dims."); + } + + *axis_index = (axis + DIM_DEFAULT_SIZE) % DIM_DEFAULT_SIZE; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeScaleParser("Scale", new CaffeScaleParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h new file mode 100644 index 00000000000..cdd5c70726e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_scale_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeScaleParser : public CaffeNodeParser { + public: + CaffeScaleParser() : CaffeNodeParser("scale") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; + + STATUS GetAxisIndex(const int32_t &axis, uint32_t *axis_index); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SCALE_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc new file mode 100644 index 00000000000..20c2590ff83 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h" + +namespace mindspore { +namespace lite { +STATUS CaffeSigmoidParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_Activation; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeSigmoidParser("Sigmoid", new CaffeSigmoidParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h new file mode 100644 index 00000000000..5f795b11d32 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_sigmoid_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeSigmoidParser : public CaffeNodeParser { + public: + CaffeSigmoidParser() : CaffeNodeParser("sigmoid") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SIGMOID_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc new file mode 100644 index 00000000000..399be822c1c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h" +#include "utils/log_adapter.h" + +static const int32_t CAFFE_SOFTMAX_DEFAULT_AXIS = 1; + +namespace mindspore { +namespace lite { +STATUS CaffeSoftmaxParser::Parse(const caffe::LayerParameter &proto, + const caffe::LayerParameter &weight, + schema::CNodeT *op, + std::vector *weightVec) { + std::unique_ptr attr(new schema::SoftMaxT()); + if (proto.has_softmax_param() && proto.softmax_param().has_axis()) { + if (proto.softmax_param().axis() == -1) { + MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims."; + } + attr->axis = proto.softmax_param().axis(); + } else { + attr->axis = CAFFE_SOFTMAX_DEFAULT_AXIS; + } + op->primitive = std::make_unique(); + op->primitive->value.value = attr.release(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + return RET_OK; +} + +CaffeNodeRegistrar g_caffeSoftmaxParser("Softmax", new CaffeSoftmaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h new file mode 100644 index 00000000000..f8675d4fd58 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_softmax_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeSoftmaxParser : public CaffeNodeParser { + public: + CaffeSoftmaxParser() : CaffeNodeParser("softmax") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_CAFFE_CAFFE_SOFTMAX_PARSER_H_ + diff --git a/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt b/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt new file mode 100644 index 00000000000..d0b6cd3d52d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB_RECURSE ONNX_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +add_library(onnx_parser_mid OBJECT + ${ONNX_SRC_LIST} + ) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx.proto b/mindspore/lite/tools/converter/parser/onnx/onnx.proto new file mode 100644 index 00000000000..093fcf99c0d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx.proto @@ -0,0 +1,569 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +message TensorAnnotation { + optional string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +} \ No newline at end of file diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc new file mode 100644 index 00000000000..c6990ad0854 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxArgMaxParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ArgMaxT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "keepdims") { + attr->keepDims = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxArgMaxParser("ArgMax", new OnnxArgMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h new file mode 100644 index 00000000000..609aa539569 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_argmax_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARGMAX_PARSER_H +#define MS_ONNX_ARGMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxArgMaxParser : public OnnxNodeParser { + public: + OnnxArgMaxParser() : OnnxNodeParser("ArgMax") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc new file mode 100644 index 00000000000..44ee34f8fef --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -0,0 +1,270 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxAddParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxSubParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxDivParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_RealDiv; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMeanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mean; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxPowParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Power; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxEqualParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Equal; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxLessParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Less; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxGreaterParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Greater; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxMinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Min; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxEltwiseParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + std::unique_ptr attr(new schema::EltwiseT()); + if (onnx_node.op_type() == "Prod") { + attr->mode = schema::EltwiseMode_PROD; + } else if (onnx_node.op_type() == "Prod") { + attr->mode = schema::EltwiseMode_SUM; + } else if (onnx_node.op_type() == "Sum") { + attr->mode = schema::EltwiseMode_MAXIMUM; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Eltwise; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +STATUS OnnxFloorParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Floor; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAbsParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Abs; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxNegParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Neg; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxExpParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Exp; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxCosParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cos; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxSinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sin; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxSqrtParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sqrt; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxCeilParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Ceil; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxLogParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Log; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxTanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Tan; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAtanParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Atan; + op->primitive->value.value = nullptr; + } + return RET_OK; +} +STATUS OnnxAsinParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Asin; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); +OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); +OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); +OnnxNodeRegistrar g_onnxMulParser("Mul", new OnnxMulParser()); +OnnxNodeRegistrar g_onnxDivParser("Div", new OnnxDivParser()); +OnnxNodeRegistrar g_onnxMeanParser("Mean", new OnnxMeanParser()); +OnnxNodeRegistrar g_onnxPowParser("Power", new OnnxPowParser()); +OnnxNodeRegistrar g_onnxEqualParser("Equal", new OnnxEqualParser()); +OnnxNodeRegistrar g_onnxLessParser("Less", new OnnxLessParser()); +OnnxNodeRegistrar g_onnxGreaterParser("Greater", new OnnxGreaterParser()); +OnnxNodeRegistrar g_onnxMinParser("Min", new OnnxMinParser()); +OnnxNodeRegistrar g_onnxProdParser("Prod", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxSumParser("Sum", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxMaxParser("Max", new OnnxEltwiseParser()); +OnnxNodeRegistrar g_onnxFloorParser("Floor", new OnnxFloorParser()); +OnnxNodeRegistrar g_onnxAbsParser("Abs", new OnnxAbsParser()); +OnnxNodeRegistrar g_onnxNegParser("Neg", new OnnxNegParser()); +OnnxNodeRegistrar g_onnxExpParser("Exp", new OnnxExpParser()); +OnnxNodeRegistrar g_onnxCosParser("Cos", new OnnxCosParser()); +OnnxNodeRegistrar g_onnxSinParser("Sin", new OnnxSinParser()); +OnnxNodeRegistrar g_onnxSqrtParser("Sqrt", new OnnxSqrtParser()); +OnnxNodeRegistrar g_onnxCeilParser("Ceil", new OnnxCeilParser()); +OnnxNodeRegistrar g_onnxLogParser("Log", new OnnxLogParser()); +OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser()); +OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser()); +OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); +OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h new file mode 100644 index 00000000000..bbb62e908f7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -0,0 +1,171 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARITHMETIC_OPREATION_PARSER_H +#define MS_ONNX_ARITHMETIC_OPREATION_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxAddParser : public OnnxNodeParser { + public: + OnnxAddParser() : OnnxNodeParser("Add") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSubParser : public OnnxNodeParser { + public: + OnnxSubParser() : OnnxNodeParser("Sub") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMulParser : public OnnxNodeParser { + public: + OnnxMulParser() : OnnxNodeParser("Mul") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxDivParser : public OnnxNodeParser { + public: + OnnxDivParser() : OnnxNodeParser("Div") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMeanParser : public OnnxNodeParser { + public: + OnnxMeanParser() : OnnxNodeParser("Mean") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxPowParser : public OnnxNodeParser { + public: + OnnxPowParser() : OnnxNodeParser("Power") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxEqualParser : public OnnxNodeParser { + public: + OnnxEqualParser() : OnnxNodeParser("Equal") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLessParser : public OnnxNodeParser { + public: + OnnxLessParser() : OnnxNodeParser("Less") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxGreaterParser : public OnnxNodeParser { + public: + OnnxGreaterParser() : OnnxNodeParser("Greater") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxMinParser : public OnnxNodeParser { + public: + OnnxMinParser() : OnnxNodeParser("Min") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxEltwiseParser : public OnnxNodeParser { + public: + OnnxEltwiseParser() : OnnxNodeParser("Eltwise") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxFloorParser : public OnnxNodeParser { + public: + OnnxFloorParser() : OnnxNodeParser("Floor") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAbsParser : public OnnxNodeParser { + public: + OnnxAbsParser() : OnnxNodeParser("Abs") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxNegParser : public OnnxNodeParser { + public: + OnnxNegParser() : OnnxNodeParser("Neg") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxExpParser : public OnnxNodeParser { + public: + OnnxExpParser() : OnnxNodeParser("Exp") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxCosParser : public OnnxNodeParser { + public: + OnnxCosParser() : OnnxNodeParser("Cos") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSinParser : public OnnxNodeParser { + public: + OnnxSinParser() : OnnxNodeParser("Sin") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxSqrtParser : public OnnxNodeParser { + public: + OnnxSqrtParser() : OnnxNodeParser("Sqrt") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxCeilParser : public OnnxNodeParser { + public: + OnnxCeilParser() : OnnxNodeParser("Ceil") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLogParser : public OnnxNodeParser { + public: + OnnxLogParser() : OnnxNodeParser("Log") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxTanParser : public OnnxNodeParser { + public: + OnnxTanParser() : OnnxNodeParser("Tan") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAtanParser : public OnnxNodeParser { + public: + OnnxAtanParser() : OnnxNodeParser("Atan") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxAsinParser : public OnnxNodeParser { + public: + OnnxAsinParser() : OnnxNodeParser("Asin") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxTanhParser : public OnnxNodeParser { + public: + OnnxTanhParser() : OnnxNodeParser("Tanh") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARITHMETIC_OPREATION_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc new file mode 100644 index 00000000000..d4ea5cdde5e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS OnnxBatchNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::FusedBatchNormT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "epsilon") { + attr->epsilon = onnx_node_attr.f(); + } else if (onnx_node_attr.name() == "momentum") { + attr->momentum = onnx_node_attr.f(); + } else if (onnx_node_attr.name() == "spatial") { + attr->spatial = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxBatchNormParser("BatchNormalization", new OnnxBatchNormParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h new file mode 100644 index 00000000000..c6b6fdb70c2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_batchnorm_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ADD_PARSER_H +#define MS_ONNX_ADD_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxBatchNormParser : public OnnxNodeParser { + public: + OnnxBatchNormParser() : OnnxNodeParser("BatchNormalization") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc new file mode 100644 index 00000000000..be5229ed073 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h" + +// using namespace mindspore::predict; +// using namespace onnx; +// using namespace std; +namespace mindspore { +namespace lite { +STATUS OnnxBiasAddParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::BiasAddT()); + // use channel dim as axis + attr->axis = {1}; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_BiasAdd; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxBiasAddParser("BiasAdd", new OnnxBiasAddParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h new file mode 100644 index 00000000000..06497be3f08 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_biasadd_parser.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_BIASADD_PARSER_H +#define MS_ONNX_BIASADD_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxBiasAddParser : public OnnxNodeParser { + public: + OnnxBiasAddParser() : OnnxNodeParser("BiasAdd") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_BIASADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc new file mode 100644 index 00000000000..66ef7542ebd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::CastT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "to") { + attr->dstT = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Cast; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxCastParser("Cast", new OnnxCastParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h new file mode 100644 index 00000000000..8d028379aa3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CAST_PARSER_H +#define MS_ONNX_CAST_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxCastParser : public OnnxNodeParser { + public: + OnnxCastParser() : OnnxNodeParser("Cast") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CAST_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc new file mode 100644 index 00000000000..d27994b365a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::ClipT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "max") { + attr->max = onnx_node_attr.f(); + } else if (attribute_name == "min") { + attr->min = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Clip; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxClipParser("Clip", new OnnxClipParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h new file mode 100644 index 00000000000..00532e73eb2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CLIP_PARSER_H +#define MS_ONNX_CLIP_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxClipParser : public OnnxNodeParser { + public: + OnnxClipParser() : OnnxNodeParser("Clip") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc new file mode 100644 index 00000000000..20549991921 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxConcatParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ConcatT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConcatParser("Concat", new OnnxConcatParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h new file mode 100644 index 00000000000..b38039cd7bb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_concat_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONCAT_PARSER_H +#define MS_ONNX_CONCAT_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConcatParser : public OnnxNodeParser { + public: + OnnxConcatParser() : OnnxNodeParser("Concat") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc new file mode 100644 index 00000000000..ff618aba0a9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxConstantParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Constant; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConstantParser("Constant", new OnnxConstantParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h new file mode 100644 index 00000000000..0356057b281 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONSTANT_PARSER_H +#define MS_ONNX_CONSTANT_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConstantParser : public OnnxNodeParser { + public: + OnnxConstantParser() : OnnxNodeParser("Constant") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONSTANT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc new file mode 100644 index 00000000000..89ed7efb864 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h" + +namespace mindspore { +namespace lite { +bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) { + if (attr == nullptr || attr->group != attr->channelIn) { + return false; + } + std::unique_ptr depthwiseConv2DParam(new (std::nothrow) schema::DepthwiseConv2DT()); + if (depthwiseConv2DParam == nullptr) { + // MS_LOGW("new DepthwiseConv2DT failed"); + return false; + } + depthwiseConv2DParam->format = attr->format; + depthwiseConv2DParam->channelIn = attr->channelIn; + depthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + depthwiseConv2DParam->kernelW = attr->kernelW; + depthwiseConv2DParam->kernelH = attr->kernelH; + depthwiseConv2DParam->strideW = attr->strideW; + depthwiseConv2DParam->strideH = attr->strideH; + depthwiseConv2DParam->padMode = attr->padMode; + depthwiseConv2DParam->padUp = attr->padUp; + depthwiseConv2DParam->padDown = attr->padDown; + depthwiseConv2DParam->padLeft = attr->padLeft; + depthwiseConv2DParam->padRight = attr->padRight; + depthwiseConv2DParam->dilateW = attr->dilateW; + depthwiseConv2DParam->dilateH = attr->dilateH; + depthwiseConv2DParam->hasBias = attr->hasBias; + depthwiseConv2DParam->activationType = attr->activationType; + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + delete (op->primitive->value.value); + op->primitive->value.value = depthwiseConv2DParam.release(); + return true; +} + +STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + auto attr = new schema::Conv2DT(); + // set opdef each attr params + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "group") { + attr->group = static_cast(onnx_node_attr.i()); + } else if (onnx_node_attr.name() == "dilations") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->dilateW = static_cast(onnx_node_attr.ints(0)); + attr->dilateH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernels") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelH = static_cast(onnx_node_attr.ints(0)); + attr->kernelW = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernel_shape") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelW = static_cast(onnx_node_attr.ints(0)); + attr->kernelH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "auto_pad") { + attr->padMode = GetOnnxPadMode(onnx_node_attr); + } else if (onnx_node_attr.name() == "pads") { + if (onnx_node_attr.ints().size() != 4) { + // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padLeft = static_cast(onnx_node_attr.ints(1)); + attr->padDown = static_cast(onnx_node_attr.ints(2)); + attr->padRight = static_cast(onnx_node_attr.ints(3)); + } else if (onnx_node_attr.name() == "strides") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + + const auto &onnx_conv_weight = onnx_node.input(1); + if (onnx_node.op_type() == "Conv") { + auto nodeIter = + std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); + if (nodeIter == onnx_graph.initializer().end()) { + // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector weight_shape; + auto size = (*nodeIter).dims_size(); + for (int i = 0; i < size; ++i) { + weight_shape.emplace_back((*nodeIter).dims(i)); + } + attr->channelOut = weight_shape[0]; + attr->channelIn = weight_shape[1] * attr->group; + } else { + auto nodeIter = + std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), + [onnx_conv_weight](const onnx::NodeProto &proto) { return proto.output(0) == onnx_conv_weight; }); + if (nodeIter == onnx_graph.node().end()) { + // MS_LOGE("can not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector dims; + auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); + if (iter != (*nodeIter).attribute().end()) { + dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end()); + } + attr->channelOut = dims[0]; + attr->channelIn = dims[3] * attr->group; + } + attr->format = schema::Format_NCHW; + attr->hasBias = onnx_node.input().size() == 3; + if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") { + attr->activationType = schema::ActivationType_RELU; + } else { + attr->activationType = schema::ActivationType_NO_ACTIVATION; + } + + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr; + + if (attr->group != 1) { + if (!ParseGroupConvolution(op, attr)) { + delete attr; + // MS_LOGE("Convert Convolution to Depthwise failed"); + return RET_ERROR; + } + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxConvParser("Conv", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxInt8ConvParser("Int8Conv", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxConvReluParser("ConvRelu", new OnnxConvParser()); +OnnxNodeRegistrar g_onnxInt8ConvReluParser("Int8ConvRelu", new OnnxConvParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h new file mode 100644 index 00000000000..73fa7e531cc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONV_PARSER_H +#define MS_ONNX_CONV_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxConvParser : public OnnxNodeParser { + public: + OnnxConvParser() : OnnxNodeParser("Conv") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + private: + bool ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr); +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc new file mode 100755 index 00000000000..2e7ecb90d4c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_converter.h" + +namespace mindspore { +namespace lite { +OnnxConverter::OnnxConverter() { + modelParser = new OnnxModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h new file mode 100755 index 00000000000..a6fbc75172c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_converter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_CONVERTER_H +#define MS_ONNX_CONVERTER_H +#include +#include +#include "mindspore/lite/tools/converter/converter.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" +#include "mindspore/lite/tools/converter/graphdef_transform.h" + +namespace mindspore { +namespace lite { +class OnnxConverter : public Converter { + public: + OnnxConverter(); + + ~OnnxConverter() override = default; +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_CONVERTER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc new file mode 100644 index 00000000000..7e6e021e987 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h" + +namespace mindspore { +namespace lite { +bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr) { + if (attr == nullptr || attr->group != attr->channelOut) { + return false; + } + auto deDepthwiseConv2DParam(new (std::nothrow) schema::DeDepthwiseConv2DT()); + if (deDepthwiseConv2DParam == nullptr) { + // MS_LOGW("new DeDepthwiseConv2DT failed"); + return false; + } + deDepthwiseConv2DParam->format = attr->format; + deDepthwiseConv2DParam->channelIn = attr->channelIn; + deDepthwiseConv2DParam->channelMultiplier = attr->channelOut / attr->channelIn; + deDepthwiseConv2DParam->kernelW = attr->kernelW; + deDepthwiseConv2DParam->kernelH = attr->kernelH; + deDepthwiseConv2DParam->strideW = attr->strideW; + deDepthwiseConv2DParam->strideH = attr->strideH; + deDepthwiseConv2DParam->padMode = attr->padMode; + deDepthwiseConv2DParam->padUp = attr->padUp; + deDepthwiseConv2DParam->padDown = attr->padDown; + deDepthwiseConv2DParam->padLeft = attr->padLeft; + deDepthwiseConv2DParam->padRight = attr->padRight; + deDepthwiseConv2DParam->dilateW = attr->dilateW; + deDepthwiseConv2DParam->dilateH = attr->dilateH; + deDepthwiseConv2DParam->hasBias = attr->hasBias; + deDepthwiseConv2DParam->activationType = attr->activationType; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + delete (op->primitive->value.value); + op->primitive->value.value = deDepthwiseConv2DParam; + } + return true; +} + +STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + auto attr = new schema::DeConv2DT(); + // set opdef each attr params + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "group") { + attr->group = static_cast(onnx_node_attr.i()); + } else if (onnx_node_attr.name() == "dilations") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("dilations size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->dilateW = static_cast(onnx_node_attr.ints(0)); + attr->dilateH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernels") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelH = static_cast(onnx_node_attr.ints(0)); + attr->kernelW = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "kernel_shape") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("kernel_shape size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->kernelW = static_cast(onnx_node_attr.ints(0)); + attr->kernelH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "auto_pad") { + attr->padMode = GetOnnxPadMode(onnx_node_attr); + } else if (onnx_node_attr.name() == "pads") { + if (onnx_node_attr.ints().size() != 4) { + // MS_LOGE("pads size %d is not 4", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padLeft = static_cast(onnx_node_attr.ints(1)); + attr->padDown = static_cast(onnx_node_attr.ints(2)); + attr->padRight = static_cast(onnx_node_attr.ints(3)); + } else if (onnx_node_attr.name() == "strides") { + if (onnx_node_attr.ints().size() != 2) { + // MS_LOGE("strides size %d is not 2", onnx_node_attr.ints().size()); + return RET_ERROR; + } + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } else if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + + const auto &onnx_conv_weight = onnx_node.input(1); + auto nodeIter = + std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), + [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); + if (nodeIter == onnx_graph.initializer().end()) { + // MS_LOGE("not find node: %s", onnx_conv_weight.c_str()) + return RET_ERROR; + } + std::vector weight_shape; + auto size = (*nodeIter).dims_size(); + for (int i = 0; i < size; ++i) { + weight_shape.emplace_back((*nodeIter).dims(i)); + } + MS_ASSERT(weight_shape.size() == 4); + attr->channelIn = weight_shape[0]; + attr->channelOut = weight_shape[1] * attr->group; + + attr->format = schema::Format_NCHW; + attr->hasBias = onnx_node.input().size() == 3; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DeConv2D; + op->primitive->value.value = attr; + } + + if (attr->group != 1) { + if (!ParseGroupDeConvolution(op, attr)) { + delete attr; + // MS_LOGE("Convert DeConvolution to DeDepthwise failed"); + return RET_ERROR; + } + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h new file mode 100644 index 00000000000..b4fba8bf4a9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_DECONV_PARSER_H +#define MS_ONNX_DECONV_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDeConvParser : public OnnxNodeParser { + public: + OnnxDeConvParser() : OnnxNodeParser("DeConv") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; + + private: + bool ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr); +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_DECONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc new file mode 100644 index 00000000000..ee819cb02da --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxDepthToSpaceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::DepthToSpaceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "blocksize") { + attr->blockSize = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthToSpace; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDepthToSpaceParser("DepthToSpace", new OnnxDepthToSpaceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h new file mode 100644 index 00000000000..834d71ccc9f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_DEPTH_TO_SPACE_PARSER_H +#define MS_ONNX_DEPTH_TO_SPACE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDepthToSpaceParser : public OnnxNodeParser { + public: + OnnxDepthToSpaceParser() : OnnxNodeParser("DepthToSpace") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_DEPTH_TO_SPACE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc new file mode 100644 index 00000000000..451a4c4a699 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::DropoutT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "ratio") { + attr->ratio = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Dropout; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxDropoutParser("Dropout", new OnnxDropoutParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h new file mode 100644 index 00000000000..14898f4616d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ARGMAX_PARSER_H +#define MS_ONNX_ARGMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxDropoutParser : public OnnxNodeParser { + public: + OnnxDropoutParser() : OnnxNodeParser("Dropout") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc new file mode 100644 index 00000000000..b1497c68a3d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxEluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::EluT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "alpha") { + attr->alpha = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Elu; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxEluParser("Elu", new OnnxEluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h new file mode 100644 index 00000000000..4267609791b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_elu_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_ELU_PARSER_H +#define MS_ONNX_ELU_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxEluParser : public OnnxNodeParser { + public: + OnnxEluParser() : OnnxNodeParser("Elu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc new file mode 100644 index 00000000000..ad155fbf4ed --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Broadcast; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxExpandSpaceParser("Expand", new OnnxExpandParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h new file mode 100644 index 00000000000..604281dbfb1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_expand_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_EXPAND_PARSER_H +#define MS_ONNX_EXPAND_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxExpandParser : public OnnxNodeParser { + public: + OnnxExpandParser() : OnnxNodeParser("Expand") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_EXPAND_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc new file mode 100644 index 00000000000..d06c620731d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxFlattenParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReshapeT()); + int axis = 1; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + axis = static_cast(onnx_node_attr.i()); + } + } + for (int i = 0; i < axis; ++i) { + attr->shape.emplace_back(0); + } + attr->shape.emplace_back(-1); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxFlattenParser("Flatten", new OnnxFlattenParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h new file mode 100644 index 00000000000..cacecc37586 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_flatten_parser.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_FLATTEN_PARSER_H +#define MS_ONNX_FLATTEN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxFlattenParser : public OnnxNodeParser { + public: + OnnxFlattenParser() : OnnxNodeParser("Fatten") {} + + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_FLATTEN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc new file mode 100644 index 00000000000..b9afef1a250 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxGatherParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::GatherT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Gather; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxGatherParser("Gather", new OnnxGatherParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h new file mode 100644 index 00000000000..ef2d306f596 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_gather_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_GATHER_PARSER_H +#define MS_ONNX_GATHER_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxGatherParser : public OnnxNodeParser { + public: + OnnxGatherParser() : OnnxNodeParser("Gather") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_GATHER_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc new file mode 100644 index 00000000000..0590a291e63 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::LrnT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "size") { + attr->size = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "alpha") { + attr->alpha = onnx_node_attr.f(); + } else if (attribute_name == "beta") { + attr->beta = onnx_node_attr.f(); + } else if (attribute_name == "bias") { + attr->bias = onnx_node_attr.f(); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Lrn; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h new file mode 100644 index 00000000000..e3b15045a2d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_lrn_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_LRN_PARSER_H +#define MS_ONNX_LRN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxLrnParser : public OnnxNodeParser { + public: + OnnxLrnParser() : OnnxNodeParser("Lrn") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_LRN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc new file mode 100644 index 00000000000..857d38e2076 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::MatMulT()); + float alpha = 1.0f; + float beta = 1.0f; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "transA") { + attr->transposeA = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "transB") { + attr->transposeB = static_cast(onnx_node_attr.i()); + } else if (attribute_name == "alpha") { + alpha = onnx_node_attr.f(); + } else if (attribute_name == "beta") { + beta = onnx_node_attr.f(); + } + } + if (alpha != 1 || beta != 1) { + // MS_LOGE("not support alpha * A * B + beta * C"); + return RET_PARAM_INVALID; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_MatMul; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h new file mode 100644 index 00000000000..9c7565ded04 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_MATMUL_PARSER_H +#define MS_ONNX_MATMUL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxMatmulParser : public OnnxNodeParser { + public: + OnnxMatmulParser() : OnnxNodeParser("MatMul") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_MATMUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc new file mode 100755 index 00000000000..2f06beb0a5c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -0,0 +1,512 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h" +#include "tools/common/graph_util.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +OnnxModelParser::OnnxModelParser() = default; +OnnxModelParser::~OnnxModelParser() = default; + +static const std::unordered_map TYPE_MAP = { + {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8}, + {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8}, + {onnx::TensorProto_DataType_INT16, mindspore::kNumberTypeInt16}, + {onnx::TensorProto_DataType_INT32, mindspore::kNumberTypeInt32}, + {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32}, + {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64}, + {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat}}; + +TypeId OnnxModelParser::GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { + auto iter = TYPE_MAP.find(onnx_type); + if (iter == TYPE_MAP.end()) { + return kTypeUnknown; + } + return iter->second; +} + +std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value) { + std::vector dims; + const auto shape_info = onnx_value.type().tensor_type().shape(); + for (const auto &it : onnx_value.type().tensor_type().shape().dim()) { + dims.emplace_back(it.dim_value()); + } + return dims; +} + +STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { + std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); + if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) { + // MS_LOGE("get realpath %s fail", modelFile.c_str()); + return RET_ERROR; + } + int fd = open(onnx_file.get(), O_RDONLY); + google::protobuf::io::FileInputStream input(fd); + google::protobuf::io::CodedInputStream code_input(&input); + code_input.SetTotalBytesLimit(INT_MAX, 536870912); + bool ret = onnx_model->ParseFromCodedStream(&code_input); + if (!ret) { + // MS_LOGE("load onnx file failed"); + return RET_ERROR; + } + (void)close(fd); + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { + // MS_LOGD("set onnx constant tensors"); + for (const auto &onnx_const_value : onnx_graph.initializer()) { + std::vector dims; + std::copy(onnx_const_value.dims().begin(), onnx_const_value.dims().end(), std::back_inserter(dims)); + auto data_type = GetDateTypeFromOnnx(static_cast(onnx_const_value.data_type())); + if (data_type == kTypeUnknown) { + // MS_LOGE("not support onnx type %d", static_cast(onnx_const_value.data_type())); + return RET_ERROR; + } + std::unique_ptr tensor(new (std::nothrow) schema::TensorT); + if (tensor == nullptr) { + // MS_LOGE("new tensor failed"); + return RET_ERROR; + } + tensor->dataType = data_type; + tensor->format = schema::Format_NCHW; + for (const auto &it : dims) { + tensor->dims.emplace_back(it); + } + tensor->nodeType = schema::NodeType_ValueNode; + if (CopyOnnxTensorData(onnx_const_value, tensor.get())) { + return RET_ERROR; + } + const auto index = tensor_cache->AddTensor(onnx_const_value.name(), tensor.release(), GRAPH_INPUT); + // MS_LOGD("add const tensor: %s, index %d", onnx_const_value.name().c_str(), index) + } + return RET_OK; +} + +STATUS OnnxModelParser::AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor) { + auto data_type = GetDateTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); + if (data_type == kTypeUnknown) { + // MS_LOGE("not support onnx type %d", + // static_cast(proto.type().tensor_type().elem_type())); + return RET_ERROR; + } + tensor->dataType = data_type; + tensor->dims = GetDimsFromOnnxValue(proto); + tensor->format = schema::Format_NCHW; + tensor->nodeType = schema::NodeType_ValueNode; + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, + TensorCache *tensor_cache) { + for (const auto &input_value : onnx_graph.input()) { + auto ret = tensor_cache->FindTensor(input_value.name()); + if (ret < 0) { + std::unique_ptr tensor(new schema::TensorT); + if (AddTensorCache(input_value, tensor.get())) { + return RET_ERROR; + } + auto tensor_index = tensor_cache->AddTensor(input_value.name(), tensor.release(), GRAPH_INPUT); + graph->inputIndex.emplace_back(static_cast(tensor_index)); + // MS_LOGD("input_value name: %s, graph input index: %d", input_value.name().c_str(), tensor_index); + } + } + return RET_OK; +} + +STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, + TensorCache *tensor_cache) { + for (const auto &output_value : onnx_graph.output()) { + std::unique_ptr tensor(new schema::TensorT); + if (AddTensorCache(output_value, tensor.get())) { + return RET_ERROR; + } + auto tensor_index = tensor_cache->AddTensor(output_value.name(), tensor.release(), OP_OUTPUT); + graph->outputIndex.emplace_back(tensor_index); + // MS_LOGD("output_value name: %s, graph output index: %d", output_value.name().c_str(), tensor_index); + } + return RET_OK; +} + +void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, TensorCache *tensor_cache) { + std::unique_ptr dst_op_1(new schema::CNodeT); + dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); + // dst_op_1->fmkType = FmkType_ONNX; + ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); + auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0); + std::vector matmul_inputs{onnx_node.input(0), onnx_node.input(1)}; + std::vector matmul_outputs{matmul_output_id}; + SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache); + SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache); + graph->nodes.emplace_back(std::move(dst_op_1)); + + std::unique_ptr dst_op_2(new schema::CNodeT); + dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0); + // dst_op_2->fmkType = FmkType_ONNX; + ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get()); + std::vector biasadd_inputs{matmul_output_id, onnx_node.input(2)}; + std::vector biasadd_outputs{onnx_node.output(0)}; + SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache); + SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache); + graph->nodes.emplace_back(std::move(dst_op_2)); +} + +STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { + // convert GivenTensorFill node to a weight/bias tensor + auto ret = tensor_cache->FindTensor(onnx_node.output(0)); + if (ret < 0) { + std::unique_ptr tensor(new schema::TensorT); + std::vector shape; + auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; }); + if (iter != onnx_node.attribute().end()) { + (void)shape.insert(shape.begin(), iter->ints().begin(), iter->ints().end()); + std::for_each(shape.begin(), shape.end(), [](int sh) { /*MS_LOGD("shape: %d", sh);*/ }); + } + tensor->dims = shape; + tensor->format = schema::Format_NUM_OF_FORMAT; + tensor->nodeType = schema::NodeType_ValueNode; + iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(), + [](const onnx::AttributeProto &attr) { return attr.name() == "values"; }); + // copy GivenIntTensorFill node value to tensor + if (iter != onnx_node.attribute().end()) { + size_t data_count = 1; + std::for_each(shape.begin(), shape.end(), [&data_count](int dim) { data_count *= dim; }); + size_t data_size = 0; + if (onnx_node.op_type() == "Int8GivenIntTensorFill") { + // todo how to read onnx-ori-dataType + tensor->dataType = kNumberTypeInt32; + data_size = data_count * sizeof(int32_t) / sizeof(uint8_t); + tensor->data.resize(data_size); + void *tensorData = tensor->data.data(); + auto castedTensorData = static_cast(tensorData); + MS_ASSERT(castedTensorData != nullptr); + for (size_t i = 0; i < data_count; i++) { + castedTensorData[i] = int32_t(iter->ints().data()[i]); + } + } else if (onnx_node.op_type() == "Int8GivenTensorFill") { + // todo how to read onnx-ori-dataType + tensor->dataType = kNumberTypeUInt8; + // todo: add * sizof(string) + data_size = data_count; + tensor->data.resize(data_size); + // MS_LOGD("tensor data size %lu, s: %lu", data_size, sizeof(iter->s().data())); + if (memcpy_s(tensor->data.data(), data_size, iter->s().data(), data_size) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + } else { + // MS_LOGE("unsupported data type %d", tensor->dataType); + return RET_ERROR; + } + } + auto index = tensor_cache->AddTensor(onnx_node.output(0), tensor.release(), GRAPH_INPUT); + // MS_LOGD("add given tensor: %d", index); + } + return RET_OK; +} + +STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, + TensorCache *tensor_cache) { + // change op_type() to name(), that is unique + dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); + // dst_op->fmkType = FmkType_ONNX; + // MS_LOGD("onnx op name %s, dst op name: %s, input size %d", onnx_node.op_type().c_str(), dst_op->name.c_str(), + // onnx_node.input_size()); + // get the real op type + SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); + auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); + if (status != RET_OK) { + // MS_LOGE("parser onnx node attr failed"); + return status; + } + // set op input index + std::vector node_inputs; + (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); + if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { + // MS_LOGE("SetOpInputIndex failed"); + return RET_ERROR; + } + // set op output index + std::vector node_outputs; + (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); + if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { + // MS_LOGE("SetOpOutputIndex failed"); + return RET_ERROR; + } + return RET_OK; +} + +void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { + MS_ASSERT(dst_op != nullptr); + MS_ASSERT(tensor_cache != nullptr); + std::vector quant_node_name; + quant_node_name.insert(quant_node_name.begin(), onnx_node.input().begin(), onnx_node.input().end()); + quant_node_name.insert(quant_node_name.end(), onnx_node.output().begin(), onnx_node.output().end()); + std::vector quant_node; + for (const auto &str : quant_node_name) { + for (auto &node : onnx_graph.node()) { + if (node.output(0) == str) { + quant_node.emplace_back(node); + break; + } + } + } + auto needQuantParams = size_t(onnx_node.input().size() + onnx_node.output().size()); + for (auto iter = onnx_node.input().begin(); iter != onnx_node.input().end(); iter++) { + if (IsContain(this->graphInputNames, *iter)) { + needQuantParams--; + } + } + size_t findQuantParams = 0; + for (const auto &node : quant_node) { + std::unique_ptr quant_param(new (std::nothrow) schema::QuantParamT()); + if (quant_param == nullptr) { + // MS_LOGE("new QuantParamT failed, node: %s", dst_op->name.c_str()); + return; + } + // std::unique_ptr quant_param_array(new (std::nothrow) QuantParamArrayT()); + if (quant_param == nullptr) { + // MS_LOGE("new QuantParamArrayT failed, node: %s", dst_op->name.c_str()); + return; + } + int argNum = 0; + for (const auto &onnx_node_attr : node.attribute()) { + if (onnx_node_attr.name() == "Y_scale") { + quant_param->scale = onnx_node_attr.f(); + argNum++; + } else if (onnx_node_attr.name() == "Y_zero_point") { + quant_param->zeroPoint = static_cast(onnx_node_attr.i()); + argNum++; + } + } + if (argNum != 2) { + quant_param->scale = FLT_MAX; + quant_param->zeroPoint = 0; + quant_param->min = FLT_MAX; + quant_param->max = FLT_MAX; + } + // quant_param_array->param.emplace_back(std::move(quant_param)); + dst_tensor->quantParams.emplace_back(std::move(quant_param)); + if (argNum == 2) { + findQuantParams++; + } + } + if (findQuantParams == needQuantParams) { + dst_op->quantType = schema::QuantType_AwareTrainning; + } else { + dst_op->quantType = schema::QuantType_QUANT_NONE; + } +} + +STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op) { + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); + if (node_parser == nullptr) { + // MS_LOGE("not find %s, node parser is nullptr", onnx_op_type.c_str()); + return RET_NULL_PTR; + } + return node_parser->Parse(onnx_graph, onnx_node, dst_op); +} + +STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { + schema::Format format = schema::Format_MAX; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + if (onnx_node_attr.name() == "order") { + if (onnx_node_attr.s() == "NHWC") { + format = schema::Format_NHWC; + } else { + // MS_LOGE("Unsupported format: %s", onnx_node_attr.s().c_str()); + return RET_ERROR; + } + } + } + for (const auto &onnx_node_input : node_inputs) { + auto index = tensor_cache->FindTensor(onnx_node_input); + if (index < 0) { + std::unique_ptr tensor(new schema::TensorT); + index = tensor_cache->AddTensor(onnx_node_input, tensor.release(), OP_OUTPUT); + } + if (format != schema::Format_MAX) { + auto inTensor = tensor_cache->GetCachedTensor().at(index); + inTensor->format = format; + } + // MS_LOGD("node: %s, input index: %d", onnx_node_input.c_str(), index); + dst_op->inputIndex.emplace_back(index); + } + return RET_OK; +} + +STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, + TensorCache *tensor_cache) { + for (const auto &onnx_node_output : node_outputs) { + auto index = tensor_cache->FindTensor(onnx_node_output); + if (index < 0) { + std::unique_ptr tensor(new schema::TensorT); + index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT); + } + // MS_LOGD("node: %s, input index: %d", onnx_node_output.c_str(), index); + dst_op->outputIndex.emplace_back(index); + } + return RET_OK; +} + +STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, + schema::TensorT *tensor) { + size_t data_count = 1; + std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); + size_t data_size = 0; + const void *tensor_data = nullptr; + switch (tensor->dataType) { + case kNumberTypeFloat: + data_size = data_count * sizeof(float); + if (onnx_const_value.float_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.float_data().data(); + } + break; + case kNumberTypeInt32: + data_size = data_count * sizeof(int); + if (onnx_const_value.int32_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.int32_data().data(); + } + break; + case kNumberTypeInt64: + data_size = data_count * sizeof(int64_t); + if (onnx_const_value.int64_data_size() == 0) { + tensor_data = onnx_const_value.raw_data().data(); + } else { + tensor_data = onnx_const_value.int64_data().data(); + } + break; + case kNumberTypeUInt8: + case kNumberTypeInt8: + data_size = data_count * sizeof(uint8_t); + tensor_data = onnx_const_value.raw_data().data(); + break; + default: + // MS_LOGE("unsupported data type %d", tensor->dataType); + return RET_ERROR; + } + tensor->data.resize(data_size); + if (memcpy_s(static_cast(tensor->data.data()), data_size, tensor_data, data_size) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + return RET_OK; +} + +STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { + std::vector tensors = tensor_cache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + graphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) { + this->graphInputNames.clear(); + this->graphConstNames.clear(); + for (auto &onnx_const : onnx_graph.initializer()) { + this->graphConstNames.emplace_back(onnx_const.name()); + } + for (auto &onnx_input : onnx_graph.input()) { + if (!IsContain(this->graphConstNames, onnx_input.name())) { + this->graphInputNames.emplace_back(onnx_input.name()); + } + } +} + +MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { + // MS_LOGE("Input illegal: modelFile must be *.onnx"); + return nullptr; + } + std::unique_ptr dst_graph(new schema::MetaGraphT()); + onnx::ModelProto onnx_model; + if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { + // MS_LOGE("read onnx model fail"); + return nullptr; + } + const onnx::GraphProto &onnx_graph = onnx_model.graph(); + // MS_LOGI("model producer name: %s, graph name: %s", onnx_model.producer_name().c_str(), onnx_graph.name().c_str()); + TensorCache tensor_cache; + dst_graph->name = onnx_graph.name(); + // find out input names and const names + FindGraphInputAndConst(onnx_graph); + // set const tensor + if (SetGraphConstTensor(onnx_graph, &tensor_cache)) { + // MS_LOGE("SetGraphConstTensor failed"); + return nullptr; + } + // init onnx model graph input tensor + if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { + // MS_LOGE("SetGraphInputTensor failed"); + return nullptr; + } + // init onnx model graph output tensor + if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { + // MS_LOGE("SetGraphOutputTensor failed"); + return nullptr; + } + // init op node input/output tensor, and dst_op attr + for (const auto &onnx_node : onnx_graph.node()) { + if (onnx_node.op_type() == "Gemm") { + ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); + continue; + } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { + auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); + if (status != RET_OK) { + // MS_LOGE("ParseOnnxGivenFillNode failed: %d", status); + return nullptr; + } + continue; + } + + std::unique_ptr dst_op(new schema::CNodeT); + std::unique_ptr dst_tensor(new schema::TensorT); + if (ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache)) { + // MS_LOGE("parse node %s failed", onnx_node.op_type().c_str()) + return nullptr; + } + dst_graph->nodes.emplace_back(std::move(dst_op)); + } + SetAllTensors(tensor_cache, dst_graph.get()); + dst_graph->mempoolSize = 0; + dst_graph->name = GetModelName(modelFile); + return dst_graph.release(); +// return Fb2Anf(dst_graph.release()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h new file mode 100644 index 00000000000..f179082a705 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_MODEL_PARSER_H +#define MS_ONNX_MODEL_PARSER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "mindspore/lite/tools/converter/model_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +class OnnxModelParser : public ModelParser { + public: + OnnxModelParser(); + virtual ~OnnxModelParser(); + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + + private: + TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); + STATUS ReadOnnxModelFromBinary(const std::string &modelFile, google::protobuf::Message *model_proto); + STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); + STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + STATUS AddTensorCache(const onnx::ValueInfoProto &proto, schema::TensorT *tensor); + STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache); + void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, + TensorCache *tensor_cache); + STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op); + void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache); + STATUS SetOpInputIndex(const std::vector &node_inputs, + schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, + TensorCache *tensor_cache); + STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); + STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); + STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); + void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph); + + private: + std::vector graphInputNames; + std::vector graphConstNames; +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_MODEL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc new file mode 100644 index 00000000000..cd225232cf8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" + +namespace mindspore { +namespace lite { +schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr) { + if (onnx_node_attr.s() == "NOTSET") { + return schema::PadMode_NOTSET; + } else if (onnx_node_attr.s() == "SAME_UPPER" || onnx_node_attr.s() == "SAME_LOWER") { + return schema::PadMode_SAME; + } else if (onnx_node_attr.s() == "VALID") { + return schema::PadMode_VALID; + } else { + // MS_LOGE("unsupported padMode"); + return schema::PadMode_NOTSET; + } +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h new file mode 100644 index 00000000000..a479d9b0335 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_NODE_PARSER_H +#define MS_ONNX_NODE_PARSER_H + +#include +#include "google/protobuf/message.h" +#include "mindspore/lite/tools/converter/proto/onnx.pb.h" +#include "tools/common/node_util.h" +#include "mindspore/lite/schema/inner/model_generated.h" + +// using namespace std; + +namespace mindspore { +namespace lite { +class OnnxNodeParser { + public: + explicit OnnxNodeParser(const std::string &nodeName) : name(nodeName) {} + virtual ~OnnxNodeParser() = default; + virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0; + + protected: + schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); + const std::string &name; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_NODE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc new file mode 100644 index 00000000000..daefc9964bd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" +#include + +namespace mindspore { +namespace lite { +OnnxNodeParserRegistry::OnnxNodeParserRegistry() = default; + +OnnxNodeParserRegistry::~OnnxNodeParserRegistry() = default; + +OnnxNodeParserRegistry *OnnxNodeParserRegistry::GetInstance() { + static OnnxNodeParserRegistry instance; + return &instance; +} + +OnnxNodeParser *OnnxNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + for (auto const &i : parsers) { + if (name.find(i.first) != std::string::npos) { + return i.second; + } + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h new file mode 100644 index 00000000000..f4781467df2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_OP_REGISTRY_H +#define MS_ONNX_OP_REGISTRY_H + +#include +#include +#include "mindspore/lite/tools/converter/proto/onnx.pb.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" + +namespace mindspore { +namespace lite { +class OnnxNodeParserRegistry { + public: + OnnxNodeParserRegistry(); + + virtual ~OnnxNodeParserRegistry(); + + static OnnxNodeParserRegistry *GetInstance(); + OnnxNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class OnnxNodeRegistrar { + public: + OnnxNodeRegistrar(const std::string &name, OnnxNodeParser *parser) { + OnnxNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MS_ONNX_OP_REGISTRY_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc new file mode 100644 index 00000000000..c200d14f3ab --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::PadT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "pads") { + const int size = onnx_node_attr.ints_size(); + attr->paddings.resize(size); + for (int i = 0; i < size / 2; ++i) { + attr->paddings[i * 2] = static_cast(onnx_node_attr.ints(i)); + attr->paddings[i * 2 + 1] = static_cast(onnx_node_attr.ints(i + size / 2)); + } + } else if (attribute_name == "mode") { + const auto &mode = onnx_node_attr.s(); + if (mode == "constant") { + attr->paddingmode = schema::PaddingMode_CONSTANT; + } else if (mode == "reflect") { + attr->paddingmode = schema::PaddingMode_REFLECT; + } else if (mode == "edge") { + attr->paddingmode = schema::PaddingMode_SYMMETRIC; + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pad; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxPadParser("Pad", new OnnxPadParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h new file mode 100644 index 00000000000..ba2e54bc593 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_LRN_PARSER_H +#define MS_ONNX_LRN_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxPadParser : public OnnxNodeParser { + public: + OnnxPadParser() : OnnxNodeParser("Pad") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_LRN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc new file mode 100644 index 00000000000..70e61366905 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::PoolingT()); + + const auto &pool_type = onnx_node.op_type(); + if (pool_type == "MaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->global = false; + } else if (pool_type == "AveragePool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->global = false; + } else if (pool_type == "GlobalMaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->global = true; + } else if (pool_type == "GlobalAveragePool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->global = true; + } else { + // MS_LOGE("Pooling param`s PoolingMode is not MAX either AVE. MindSpore support MAX and AVE only."); + return RET_ERROR; + } + + attr->roundMode = schema::RoundMode_FLOOR; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "kernel_shape") { + if (onnx_node_attr.ints_size() == 2) { + attr->windowW = static_cast(onnx_node_attr.ints(0)); + attr->windowH = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "strides") { + if (onnx_node_attr.ints_size() == 2) { + attr->strideW = static_cast(onnx_node_attr.ints(0)); + attr->strideH = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "auto_pad") { + MS_ASSERT(false); + } + if (attribute_name == "pads") { + if (onnx_node_attr.ints_size() == 4) { + attr->padMode = schema::PadMode_CAFFE; + attr->padUp = static_cast(onnx_node_attr.ints(0)); + attr->padDown = static_cast(onnx_node_attr.ints(1)); + attr->padLeft = static_cast(onnx_node_attr.ints(0)); + attr->padRight = static_cast(onnx_node_attr.ints(1)); + } + } + if (attribute_name == "ceil_mode") { + MS_ASSERT(false); // todo (h00500767) + attr->roundMode = schema::RoundMode_CEIL; + } + if (attribute_name == "dilations") { + MS_ASSERT(false); // todo pooling op not support dilations now + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxMaxPoolParser("MaxPool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxAveragePoolParser("AveragePool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxGlobalAveragePoolParser("GlobalAveragePool", new OnnxPoolParser()); +OnnxNodeRegistrar g_onnxGlobalMaxPoolParser("GlobalMaxPool", new OnnxPoolParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h new file mode 100644 index 00000000000..ce439cf3f16 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_POOL_PARSER_H +#define MS_ONNX_POOL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxPoolParser : public OnnxNodeParser { + public: + OnnxPoolParser() : OnnxNodeParser("Pool") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_POOL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc new file mode 100644 index 00000000000..c4453739035 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReduceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + const int &size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->axes.push_back(onnx_node_attr.ints(i)); + } + } else if (attribute_name == "keepdims") { + attr->keepDims = static_cast(onnx_node_attr.i()); + } + } + const auto &type = onnx_node.op_type(); + if (type == "ReduceMean") { + attr->mode = schema::ReduceMode_ReduceMean; + } else if (type == "ReduceMax") { + attr->mode = schema::ReduceMode_ReduceMax; + } else if (type == "ReduceMin") { + attr->mode = schema::ReduceMode_ReduceMin; + } else if (type == "ReduceSum") { + attr->mode = schema::ReduceMode_ReduceSum; + } else { + // MS_LOGE("unsupoort type"); + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reduce; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReduceMeanParser("ReduceMean", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceMaxParser("ReduceMax", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceMinParser("ReduceMin", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceProdParser("ReduceProd", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceSumParser("ReduceSum", new OnnxReduceParser()); +OnnxNodeRegistrar g_onnxReduceSumSquareParser("ReduceSumSquare", new OnnxReduceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h new file mode 100644 index 00000000000..9b19d37062c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reduce_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_REDUCE_PARSER_H +#define MS_ONNX_REDUCE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReduceParser : public OnnxNodeParser { + public: + OnnxReduceParser() : OnnxNodeParser("Reduce") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_REDUCE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc new file mode 100644 index 00000000000..9947b5fa8cb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h" +#include "securec/include/securec.h" +namespace mindspore { +namespace lite { +STATUS OnnxReluParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + unique_ptr attr(new schema::ActivationT()); + const auto &relu_type = onnx_node.op_type(); + if (relu_type == "Relu") { + attr->type = schema::ActivationType_RELU; + } else if (relu_type == "LeakyRelu") { + attr->type = schema::ActivationType_LEAKY_RELU; + } + + if (op != nullptr) { + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +STATUS OnnxPReluParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (onnx_node.input_size() != 2) { + // MS_LOGE("input num is not 2") + return RET_PARAM_INVALID; + } + unique_ptr attr(new schema::PreluT()); + std::vector params; + for (int i = 0; i < onnx_node.input_size(); ++i) { + const auto &input_name = onnx_node.input(i); + for ( const auto &it : onnx_graph.initializer() ) { + if (it.name() == "input_name") { + params.push_back(it); + break; + } + } + } + const onnx::TensorProto *slope = ¶ms[0]; + if (slope == nullptr) { + // MS_LOGE("input error") + return RET_PARAM_INVALID; + } + const auto slope_raw_data = reinterpret_cast(slope->raw_data().data()); + const int64_t slope_size = slope->raw_data().size() / sizeof(float); + if (memcpy_s(attr->slope.data(), slope_size * sizeof(float), slope_raw_data, slope_size * sizeof(float)) != 0) { + // MS_LOGE("memcpy_s failed") + return RET_ERROR; + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Prelu; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReluParser("Relu", new OnnxReluParser()); +OnnxNodeRegistrar g_onnxLeakyReluParser("LeakyRelu", new OnnxLeakeyReluParser()); +OnnxNodeRegistrar g_onnxPReluParser("Prelu", new OnnxPReluParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h new file mode 100644 index 00000000000..a3750c21be7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_relu_parser.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_RELU_PARSER_H +#define MS_ONNX_RELU_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReluParser : public OnnxNodeParser { + public: + OnnxReluParser() : OnnxNodeParser("Relu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; + +class OnnxLeakeyReluParser : public OnnxReluParser { + public: + OnnxLeakeyReluParser() : OnnxReluParser() {} +}; + +class OnnxPReluParser : public OnnxNodeParser { + public: + OnnxPReluParser() : OnnxNodeParser("Prelu") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_RELU_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc new file mode 100644 index 00000000000..c02c428494c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ReshapeT()); + attr->format = schema::Format_NHWC; + + std::vector params; + for (int i = 0; i < onnx_node.input_size(); ++i) { + const auto &input_name = onnx_node.input(i); + for (const auto &it : onnx_graph.initializer()) { + if (it.name() == input_name) { + params.emplace_back(it); + break; + } + } + } + if (params.empty()) { + return RET_OK; + } + if (params.size() != 1) { + // MS_LOGE("input num is ,not equal 1", params.size()) + return RET_PARAM_INVALID; + } + + auto pre_shape = params[0]; + for (int i = 0; i < pre_shape.dims_size(); ++i) { + attr->shape.emplace_back(params[0].dims(i)); + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxReshapeParser("Reshape", new OnnxReshapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h new file mode 100644 index 00000000000..5c0d673dfcd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_RESHAPE_PARSER_H +#define MS_ONNX_RESHAPE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxReshapeParser : public OnnxNodeParser { + public: + OnnxReshapeParser() : OnnxNodeParser("Reshape") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_RESHAPE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc new file mode 100644 index 00000000000..d85740b3b51 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxShapeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Shape; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxShapeParser("Shape", new OnnxShapeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h new file mode 100644 index 00000000000..27073aa66d4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_shape_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SHAPE_PARSER_H +#define MS_ONNX_SHAPE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxShapeParser : public OnnxNodeParser { + public: + OnnxShapeParser() : OnnxNodeParser("Shape") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SHAPE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc new file mode 100644 index 00000000000..e275e092b96 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSigmoidParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSigmoodParser("Sigmoid", new OnnxSigmoidParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h new file mode 100644 index 00000000000..55f8664965b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_sigmoid_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SIGMOID_PARSER_H +#define MS_ONNX_SIGMOID_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSigmoidParser : public OnnxNodeParser { + public: + OnnxSigmoidParser() : OnnxNodeParser("Sigmoid") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SIGMOID_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc new file mode 100644 index 00000000000..83f9c49f9fe --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SliceT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "starts") { + const int size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->begin.emplace_back(static_cast(onnx_node_attr.ints(i))); + } + } else if (attribute_name == "ends") { + const int size = onnx_node_attr.ints_size(); + for (int i = 0; i < size; ++i) { + attr->size.emplace_back(static_cast(onnx_node_attr.ints(i))); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h new file mode 100644 index 00000000000..6a45db1f31d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_slice_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SLICE_PARSER_H +#define MS_ONNX_SLICE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSliceParser : public OnnxNodeParser { + public: + OnnxSliceParser() : OnnxNodeParser("Slice") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SLICE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc new file mode 100644 index 00000000000..229cc7848a2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSoftMaxParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SoftMaxT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto& attribute_name = onnx_node_attr.name(); + if (attribute_name == "axis") { + attr->axis = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSoftMaxParser("Softmax", new OnnxSoftMaxParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h new file mode 100644 index 00000000000..822944ea5ea --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_softmax_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SOFTMAX_PARSER_H +#define MS_ONNX_SOFTMAX_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSoftMaxParser : public OnnxNodeParser { + public: + OnnxSoftMaxParser() : OnnxNodeParser("Softmax") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SOFTMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc new file mode 100644 index 00000000000..549e20329f6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSPaceToDepthParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SpaceToDepthT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "blocksize") { + attr->blockSize = static_cast(onnx_node_attr.i()); + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSpaceToDepthParser("SpaceToDepth", new OnnxSPaceToDepthParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h new file mode 100644 index 00000000000..2a47a967585 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SPACE_TO_DEPTH_PARSER_H +#define MS_ONNX_SPACE_TO_DEPTH_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSPaceToDepthParser : public OnnxNodeParser { + public: + OnnxSPaceToDepthParser() : OnnxNodeParser("SpaceToDepth") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SPACE_TO_DEPTH_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc new file mode 100644 index 00000000000..f462d6091e4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::SqueezeT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { + attr->axis.emplace_back(onnx_node_attr.ints(i)); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Squeeze; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxSqueezeParser("Squeeze", new OnnxSqueezeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h new file mode 100644 index 00000000000..f8e3050809f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_squeeze_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_SQUEEZE_PARSER_H +#define MS_ONNX_SQUEEZE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxSqueezeParser : public OnnxNodeParser { + public: + OnnxSqueezeParser() : OnnxNodeParser("Squeeze") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_SQUEEZE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc new file mode 100644 index 00000000000..b6839b958ab --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Tile; + op->primitive->value.value = nullptr; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxTileParser("Tile", new OnnxTileParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h new file mode 100644 index 00000000000..f09811e099a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_TILE_PARSER_H +#define MS_ONNX_TILE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxTileParser : public OnnxNodeParser { + public: + OnnxTileParser() : OnnxNodeParser("Tile") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_ARGMAX_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc new file mode 100644 index 00000000000..f16a2d02789 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::TransposeT()); + attr->conjugate = false; + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + attr->perm.resize(onnx_node_attr.ints_size()); + for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { + attr->perm[i] = onnx_node_attr.ints(i); + } + } + if (attribute_name == "perm") { + attr->perm.resize(onnx_node_attr.ints_size()); + for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { + attr->perm[i] = onnx_node_attr.ints(i); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Transpose; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h new file mode 100644 index 00000000000..e87279b43b3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_transpose_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_TRANSPOSE_PARSER_H +#define MS_ONNX_TRANSPOSE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxTransposeParser : public OnnxNodeParser { + public: + OnnxTransposeParser() : OnnxNodeParser("Transpose") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_TRANSPOSE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc new file mode 100644 index 00000000000..90b572d5659 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::UpsampleT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "mode") { + attr->mode = onnx_node_attr.s(); + } else if (attribute_name == "scales") { + for (int i = 0; i < onnx_node_attr.floats_size(); ++i) { + attr->scales[i] = onnx_node_attr.floats(i); + } + } + } + // to do + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Upsample; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxUpsampleParser("Upsample", new OnnxUpsampleParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h new file mode 100644 index 00000000000..2b2a6a3a6f0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UPSAMPLE_PARSER_H +#define MS_ONNX_UPSAMPLE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUpsampleParser : public OnnxNodeParser { + public: + OnnxUpsampleParser() : OnnxNodeParser("Upsample") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UPSAMPLE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc new file mode 100644 index 00000000000..8ed288e97ba --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUnSqueezeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + unique_ptr attr(new schema::UnsqueezeT()); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "axes") { + for (int i = 0; i < onnx_node_attr.ints().size(); ++i) { + attr->axis.emplace_back(onnx_node_attr.ints(i)); + } + } + } + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Unsqueeze; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxUnsqueezeParser("Unsqueeze", new OnnxUnSqueezeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h new file mode 100644 index 00000000000..231d3ef2a90 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UNSQUEEZE_PARSER_H +#define MS_ONNX_UNSQUEEZE_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUnSqueezeParser : public OnnxNodeParser { + public: + OnnxUnSqueezeParser() : OnnxNodeParser("Unsqueeze") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UNSQUEEZE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc new file mode 100644 index 00000000000..bb92abe567d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h" + +namespace mindspore { +namespace lite { +STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, + const onnx::NodeProto &onnx_node, + schema::CNodeT *op) { + if (op != nullptr) { + op->primitive = std::make_unique(); + if (onnx_node.op_type() == "Int8Quantize") { + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; + op->primitive->value.value = new (std::nothrow) schema::OnnxInt8QuantizeT; + } else if (onnx_node.op_type() == "Int8Dequantize") { + op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + op->primitive->value.value = new (std::nothrow) schema::OnnxInt8DequantizeT; + } else { + // MS_LOGE("Unsupported nodeType: %s", onnx_node.op_type().c_str()); + return RET_ERROR; + } + if (op->primitive->value.value == nullptr) { + // MS_LOGE("new %s attr value failed", onnx_node.op_type().c_str()); + return RET_ERROR; + } + } else { + // MS_LOGE("Input opDef is nullptr"); + return RET_PARAM_INVALID; + } + return RET_OK; +} + +OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser()); +OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h new file mode 100644 index 00000000000..6e002254f03 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_ONNX_UNUSEFUL_PARSER_H +#define MS_ONNX_UNUSEFUL_PARSER_H + +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h" +#include "mindspore/lite/tools/converter/parser/onnx/onnx_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class OnnxUnusefulNodeParser : public OnnxNodeParser { + public: + OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MS_ONNX_UNUSEFUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt new file mode 100644 index 00000000000..03f9b3670b8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE TFLITE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(tflite_parser_mid OBJECT + ${TFLITE_SRC_LIST} + ) diff --git a/mindspore/lite/tools/converter/parser/tflite/schema.fbs b/mindspore/lite/tools/converter/parser/tflite/schema.fbs new file mode 100644 index 00000000000..bdb07377a8c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/schema.fbs @@ -0,0 +1,926 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. + // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input and output tensors are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; +} + +root_type Model; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc new file mode 100644 index 00000000000..377ecab1676 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteAddParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + // MS_LOGD("parse TfliteAddParser"); + std::unique_ptr attr(new schema::AddT()); + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Add; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteAddParser("Add", new TfliteAddParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h new file mode 100644 index 00000000000..5add90acf59 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_add_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ADD_PARSER_H +#define PREDICT_TFLITE_ADD_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteAddParser : public TfliteNodeParser { + public: + TfliteAddParser() : TfliteNodeParser("Add") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc new file mode 100644 index 00000000000..abef7bac380 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteArgmaxParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; + std::unique_ptr attr(new schema::ArgMaxT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_ArgMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteArgmaxParser("Argmax", new TfliteArgmaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h new file mode 100644 index 00000000000..a96305e9117 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ARGMAX_PARSER_H +#define PREDICT_TFLITE_ARGMAX_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteArgmaxParser : public TfliteNodeParser { + public: + TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ARGMAX_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc new file mode 100644 index 00000000000..50454af12eb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteConcatParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + // MS_LOGD("parse TfliteConcatParser"); + std::unique_ptr attr(new schema::ConcatT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsConcatenationOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + + attr->axis = tfliteAttr->axis; + attr->n = tfliteOp->inputs.size(); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Concat; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h new file mode 100644 index 00000000000..47e767f0254 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONCAT_PARSER_H +#define PREDICT_TFLITE_CONCAT_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteConcatParser : public TfliteNodeParser { + public: + TfliteConcatParser() : TfliteNodeParser("Concat") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc new file mode 100644 index 00000000000..9bdcac401db --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteConvParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteConvParser"); + std::unique_ptr attr(new schema::Conv2DT()); + const auto &tfliteAttr = tflite_op->builtin_options.AsConv2DOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + attr->group = 1; + attr->strideW = tfliteAttr->stride_w; + attr->strideH = tfliteAttr->stride_h; + attr->dilateH = tfliteAttr->dilation_h_factor; + attr->dilateW = tfliteAttr->dilation_w_factor; + attr->padMode = GetPadMode(tfliteAttr->padding); + attr->format = schema::Format_NHWC; + attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function); + // get the conv op weight tensor + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + auto weight_shape = weight_tensor->shape; + attr->channelIn = weight_shape[KHWC_C]; + attr->channelOut = weight_shape[KHWC_K]; + attr->kernelW = weight_shape[KHWC_W]; + attr->kernelH = weight_shape[KHWC_H]; + if (tflite_op->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tflite_op->inputs[2]; + const auto &bias_tensor = tflite_tensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + // MS_LOGE("parse bias failed"); + return RET_ERROR; + } + } + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h new file mode 100644 index 00000000000..421d696829f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONV_PARSER_H +#define PREDICT_TFLITE_CONV_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteConvParser : public TfliteNodeParser { + public: + TfliteConvParser() : TfliteNodeParser("Conv2D") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc new file mode 100644 index 00000000000..11ce6cec819 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_converter.h" + +namespace mindspore { +namespace lite { +TfliteConverter::TfliteConverter() { + modelParser = new TfliteModelParser(); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h new file mode 100644 index 00000000000..4c85595ffb7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_converter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ + +#include +#include +#include "mindspore/lite/tools/converter/converter.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h" +#include "mindspore/lite/tools/converter/graphdef_transform.h" + +namespace mindspore { +namespace lite { +class TfliteConverter : public Converter { + public: + TfliteConverter(); + + ~TfliteConverter() override = default; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_CAFFE_CONVERTER_H_ + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc new file mode 100644 index 00000000000..7af495e65da --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h" +#include "tools/common/node_util.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDepthwiseConv2DParser::ParseGroupDepthwiseConv(schema::CNodeT *op, + const std::unique_ptr &attr, + const std::unique_ptr &weightTensor, + TensorCache *tensor_cache) { + std::unique_ptr convAttr(new schema::Conv2DT); + convAttr->format = attr->format; + convAttr->channelIn = attr->channelIn; + convAttr->channelOut = attr->channelIn * attr->channelMultiplier; + convAttr->kernelH = attr->kernelH; + convAttr->kernelW = attr->kernelW; + convAttr->strideH = attr->strideH; + convAttr->strideW = attr->strideW; + convAttr->padMode = attr->padMode; + convAttr->padUp = attr->padUp; + convAttr->padDown = attr->padDown; + convAttr->padLeft = attr->padLeft; + convAttr->padRight = attr->padRight; + convAttr->dilateH = attr->dilateH; + convAttr->dilateW = attr->dilateW; + convAttr->hasBias = attr->hasBias; + convAttr->activationType = attr->activationType; + + auto weightTensorIndex = tensor_cache->FindTensor(weightTensor->name); + if (weightTensorIndex >= 0 && weightTensorIndex < tensor_cache->GetCachedTensor().size()) { + auto liteWeightTensor = tensor_cache->GetCachedTensor()[weightTensorIndex]; + if (liteWeightTensor->dataType == TypeId::kNumberTypeUInt8) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + + if (liteWeightTensor->dataType == kNumberTypeFloat32 || liteWeightTensor->dataType == kNumberTypeFloat) { + // convert weight format KHWC -> CHWK + auto status = TransFilterFormat(liteWeightTensor, kKHWC2CHWK); + if (status != RET_OK) { + MS_LOG(ERROR) << "Trans depthwiseConv Filter Format failed."; + return RET_ERROR; + } + } + } + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Conv2D; + op->primitive->value.value = convAttr.release(); + return RET_OK; +} + +STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tflite_tensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteDepthwiseConv2DParser"); + std::unique_ptr attr(new schema::DepthwiseConv2DT()); + const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->dilateH = tflite_attr->dilation_h_factor; + attr->dilateW = tflite_attr->dilation_w_factor; + attr->padMode = GetPadMode(tflite_attr->padding); + attr->format = schema::Format_NHWC; + attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function); + // get the conv op weight tensor + auto input_index = tflite_op->inputs[0]; + const auto &input_tenosr = tflite_tensors[input_index]; + auto input_shape = input_tenosr->shape; + + auto weight_index = tflite_op->inputs[1]; + const auto &weight_tensor = tflite_tensors[weight_index]; + auto weight_shape = weight_tensor->shape; + attr->channelIn = input_shape[KHWC_C]; + attr->channelMultiplier = tflite_attr->depth_multiplier; + attr->kernelH = weight_shape[KHWC_H]; + attr->kernelW = weight_shape[KHWC_W]; + + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + + if (tflite_op->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tflite_op->inputs[2]; + const auto &bias_tensor = tflite_tensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + // MS_LOGE("parse bias failed"); + return RET_ERROR; + } + } + + if (attr->channelMultiplier > 1) { + if (RET_OK != ParseGroupDepthwiseConv(op, attr, weight_tensor, tensor_cache)) { + // MS_LOGE("Parse Group DepthwiseConv failed"); + return RET_ERROR; + } + } else { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h new file mode 100644 index 00000000000..0c28b3fd454 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H +#define PREDICT_TFLITE_DEPTHWISE_CONV_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDepthwiseConv2DParser : public TfliteNodeParser { + public: + TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) override; + + private: + STATUS ParseGroupDepthwiseConv(schema::CNodeT *op, + const std::unique_ptr &attr, + const std::unique_ptr &weightTensor, + TensorCache *tensor_cache); +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc new file mode 100644 index 00000000000..0e8d1721df2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tflite/tflite_fakequant_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteFullyConnectedParser"); + std::unique_ptr attr(new schema::FullConnectionT()); + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + if (tfliteOp->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tfliteOp->inputs[2]; + const auto &bias_tensor = tfliteTensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + // MS_LOGE("parse bias failed"); + return RET_ERROR; + } + } + attr->axis = 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFakeQuantParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h new file mode 100644 index 00000000000..e4c0441a55b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fakequant_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_TFLITE_FAKEQUANT_PARSER_H +#define LITE_TFLITE_FAKEQUANT_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFakeQuantParser : public TfliteNodeParser { + public: + TfliteFakeQuantParser() : TfliteNodeParser("FakeQuant") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_FAKEQUANT_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc new file mode 100644 index 00000000000..1ac1f6a6baf --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteFullyConnectedParser"); + std::unique_ptr attr(new schema::FullConnectionT()); + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + + std::vector weight_tensors{weight_tensor.get()}; + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_NHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + if (tfliteOp->inputs.size() == 3) { + attr->hasBias = true; + auto bias_index = tfliteOp->inputs[2]; + const auto &bias_tensor = tfliteTensors[bias_index]; + std::vector bias_tensors{bias_tensor.get()}; + if (RET_OK != ParseBias(bias_tensors, tfliteModelBuffer, tensor_cache)) { + // MS_LOGE("parse bias failed"); + return RET_ERROR; + } + } + attr->axis = 1; + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_FullConnection; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h new file mode 100644 index 00000000000..9906e7d90ee --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_ADD_PARSER_H +#define PREDICT_TFLITE_ADD_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteFullyConnectedParser : public TfliteNodeParser { + public: + TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc new file mode 100644 index 00000000000..4c5fa83a294 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteLogisticParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + // MS_LOGD("parse TfliteLogisticParser"); + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_SIGMOID; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h new file mode 100644 index 00000000000..3e98637a0c5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logistic_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_LOGISTIC_PARSER_H +#define PREDICT_TFLITE_LOGISTIC_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteLogisticParser : public TfliteNodeParser { + public: + TfliteLogisticParser() : TfliteNodeParser("Logistic") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONCAT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc new file mode 100644 index 00000000000..153a5d8b2a4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMaxPoolingParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("paser TfliteMaxPoolingParser"); + std::unique_ptr attr(new schema::PoolingT()); + const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + attr->format = schema::Format_NHWC; + // attr->global + attr->poolingMode = schema::PoolMode_MAX_POOLING; + attr->windowW = tflite_attr->filter_width; + attr->windowH = tflite_attr->filter_height; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->padMode = GetPadMode(tflite_attr->padding); + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TfliteMaxPoolingParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h new file mode 100644 index 00000000000..dc07947ba8a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_max_pooling_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MAX_POOLING_PARSER_H +#define PREDICT_TFLITE_MAX_POOLING_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMaxPoolingParser : public TfliteNodeParser { + public: + TfliteMaxPoolingParser() : TfliteNodeParser("MaxPooling") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc new file mode 100644 index 00000000000..2eb80d60821 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMeanParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteMeanParser"); + std::unique_ptr attr(new schema::MeanT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsReducerOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + attr->keepDims = tflite_attr->keep_dims; + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->axis)) { + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mean; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteMeanParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h new file mode 100644 index 00000000000..ec5e3a1644f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MEAN_PARSER_H +#define PREDICT_TFLITE_MEAN_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMeanParser : public TfliteNodeParser { + public: + TfliteMeanParser() : TfliteNodeParser("Mean") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MEAN_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc new file mode 100644 index 00000000000..1fb7cf254fd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMeanPoolingParser::Parse(const std::unique_ptr &tflite_op, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("paser TfliteMeanPoolingParser"); + std::unique_ptr attr(new schema::PoolingT()); + const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + attr->format = schema::Format_NHWC; + // attr->global + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + attr->windowW = tflite_attr->filter_width; + attr->windowH = tflite_attr->filter_height; + attr->strideW = tflite_attr->stride_w; + attr->strideH = tflite_attr->stride_h; + attr->padMode = GetPadMode(tflite_attr->padding); + // calculate pad params + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Pooling; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TfliteMeanPoolingParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h new file mode 100644 index 00000000000..a1208c6d1cb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mean_pooling_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MEAN_POOLING_PARSER_H +#define PREDICT_TFLITE_MEAN_POOLING_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMeanPoolingParser : public TfliteNodeParser { + public: + TfliteMeanPoolingParser() : TfliteNodeParser("MeanPooling") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc new file mode 100644 index 00000000000..88a08fd0dca --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -0,0 +1,251 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h" +#include +#include +#include +#include "tools/common/graph_util.h" +#include "tools/common/storage.h" +#include "flatbuffers/flatbuffers.h" +#include "utils/log_adapter.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +TfliteModelParser::TfliteModelParser() {} + +TfliteModelParser::~TfliteModelParser() {} + +std::unique_ptr TfliteModelParser::ReadTfliteModelFromFlat(const char *model_path) { + size_t size; + auto buf = ReadFile(model_path, &size); + if (buf == nullptr) { + // MS_LOGE("the file buffer is nullptr"); + return nullptr; + } + flatbuffers::Verifier verify((const uint8_t *)buf, size); + if (!tflite::VerifyModelBuffer(verify)) { + // MS_LOGE("the buffer is invalid and fail to create graph"); + return nullptr; + } + return tflite::UnPackModel(buf); +} + +std::string TfliteModelParser::GetTfliteNodeType(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model) { + auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; + auto msOpType = GetMSOpType(tflite_op_type); + return msOpType; +} + +STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graphDef) { + std::vector tensors = tensor_cache.GetCachedTensor(); + for (auto iter : tensors) { + std::unique_ptr temp(iter); + temp->format = schema::Format_NHWC; + sub_graphDef->allTensors.emplace_back(move(temp)); + } + return RET_OK; +} + +STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op) { + auto dst_op = tfliteOpMap.at(tflite_op.get()); + + std::vector quant_params_index; + quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end()); + quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); + for (const auto &index : quant_params_index) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { + continue; + } + std::unique_ptr quant_param(new schema::QuantParamT()); + if (!tflite_tensor->quantization->scale.empty()) { + quant_param->scale = tflite_tensor->quantization->scale[0]; + } + + if (!tflite_tensor->quantization->zero_point.empty()) { + quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; + } + + if (!tflite_tensor->quantization->min.empty()) { + quant_param->min = tflite_tensor->quantization->min[0]; + } + + if (!tflite_tensor->quantization->max.empty()) { + quant_param->max = tflite_tensor->quantization->max[0]; + } + } + dst_op->quantType = schema::QuantType_AwareTrainning; + return RET_OK; +} + +STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, schema::CNodeT *op, + TensorCache *tensorCache) { + for (const auto &index : tflite_op->outputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + std::unique_ptr tensor(new schema::TensorT()); + tensor->dataType = GetTfliteDataType(tflite_tensor->type); + tensor->dims = tflite_tensor->shape; + tensor->nodeType = schema::NodeType_Parameter; + auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); + op->outputIndex.emplace_back(opOutputIndex); + } + + return RET_OK; +} + +STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, TensorCache *tensorCache) { + for (const auto &tfliteIndex : tflite_op->inputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[tfliteIndex]; + auto tensor_name = tflite_tensor->name; + auto op = tfliteOpMap[tflite_op.get()]; + unsigned int index = tensorCache->FindTensor(tensor_name); + if (index != -1) { + op->inputIndex.push_back(index); + } + } + + return RET_OK; +} + +STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::MetaGraphT *subGraph, + mindspore::lite::TensorCache *tensorCache) { + auto i = 0; + for (const auto &tflite_op : tflite_subgraph->operators) { + auto opType = GetTfliteNodeType(tflite_op, tflite_model); + + std::unique_ptr op(new schema::CNodeT); + op->name = opType + "-" + std::to_string(i++); + + // MS_LOGD("parse op: [%s]", op->name.c_str()); + + // 1. init op attr params + auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); + if (node_parser == nullptr) { + // MS_LOGE("node %s parser is nullptr", opType.c_str()); + return RET_NULL_PTR; + } + + auto status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, + tflite_model->operator_codes, op.get(), tensorCache, false); + if (status != RET_OK) { + // MS_LOGE("node %s parser failed", opType.c_str()); + return RET_ERROR; + } + + status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); + if (status != RET_OK) { + // MS_LOGE("Set Op %s Output Index Failed!", op->name.c_str()); + return RET_ERROR; + } + + subGraph->nodes.emplace_back(std::move(op)); + opMap[subGraph->nodes.back()->name] = subGraph->nodes.back().get(); + tfliteOpMap[tflite_op.get()] = subGraph->nodes.back().get(); + } + return RET_OK; +} + +void TfliteModelParser::SetInputTensor(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + TensorCache *tensor_cache) { + for (const auto &index : tflite_subgraph->inputs) { + const auto &tflite_tensor = tflite_subgraph->tensors[index]; + std::unique_ptr tensor(new schema::TensorT()); + tensor->format = schema::Format_NHWC; + tensor->dataType = GetTfliteDataType(tflite_tensor->type); + tensor->nodeType = schema::NodeType_ValueNode; + tensor->dims = tflite_tensor->shape; + tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); + } +} + +void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, + schema::MetaGraphT *subGraphDef) { + auto opGraph = OpGraphT::Build(subGraphDef); + auto graphInputs = tensorCache.GetGraphInputs(); + auto graphOutputs = opGraph->GetOutputNode(); + + subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); + + for (const auto &output : graphOutputs) { + auto op = opMap[output->ID()]; + for (auto outputIndex : op->outputIndex) { + subGraphDef->outputIndex.emplace_back(outputIndex); + } + } +} + +MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { + std::unique_ptr subGraph(new schema::MetaGraphT); + if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { + // MS_LOGE("INPUT ILLEGAL: modelFile must be *.tflite"); + return nullptr; + } + std::unique_ptr tflite_model(new tflite::ModelT()); + tflite_model = ReadTfliteModelFromFlat(modelFile.c_str()); + if (tflite_model == nullptr) { + // MS_LOGE("read tflite model failed"); + return nullptr; + } + TensorCache tensorCache; + if (tflite_model->subgraphs.size() != 1) { + MS_LOG(ERROR) << "read tflite model subgraphs failed"; + return nullptr; + } + + const auto &tflite_subgraph = tflite_model->subgraphs[0]; + subGraph->name = "MS_model converted by TF-Lite"; + + // set dst subGraph input/output tensor + SetInputTensor(tflite_model, tflite_subgraph, &tensorCache); + // set dst subGraph op attr etc. + auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); + if (status != RET_OK) { + // MS_LOGE("ParseOp failed."); + return nullptr; + } + + for (const auto &tflite_op : tflite_subgraph->operators) { + auto statusTmp = SetOpInputIdx(tflite_subgraph, tflite_op, &tensorCache); + if (statusTmp != RET_OK) { + // MS_LOGE("Set Op %s Input Index Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); + } + } + + for (const auto &tflite_op : tflite_subgraph->operators) { + auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); + if (statusTmp != RET_OK) { + // MS_LOGE("ParseTfliteQuantParams %s Failed!", tfliteOpMap.at(tflite_op.get())->name.c_str()); + } + } + + SetGraphTensorIndex(tensorCache, subGraph.get()); + SetAllTensors(tensorCache, subGraph.get()); + return subGraph.release(); +// return Fb2Anf(subGraph.release()); +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h new file mode 100644 index 00000000000..20c5a73f483 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "mindspore/lite/tools/converter/model_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include "tools/common/tensor_util.h" + +#include "mindspore/lite/schema/inner/model_generated.h" + +// using namespace tflite; + +namespace mindspore { +namespace lite { +class TfliteModelParser : public ModelParser { + public: + TfliteModelParser(); + + virtual ~TfliteModelParser(); + + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile); + + private: + std::unique_ptr ReadTfliteModelFromFlat(const char *buf); + + void SetInputTensor(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, TensorCache *tensor_cache); + + void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, + schema::MetaGraphT *subGraphDef); + + STATUS ParseOp(const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph, + TensorCache *tensor_cache); + + STATUS ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op); + + std::string GetTfliteNodeType(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model); + + STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); + + STATUS SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, + schema::CNodeT *op, + TensorCache *tensorCache); + + STATUS SetOpInputIdx(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, TensorCache *tensorCache); + + std::map opMap; + std::map tfliteOpMap; +}; +} // namespace lite +} // namespace mindspore +#endif // PREDICT_CONV +// ERTER_PARSER_TFLITE_MODEL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc new file mode 100644 index 00000000000..ab9198819cb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteMulParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + // MS_LOGD("parse TfliteMulParser"); + std::unique_ptr attr(new schema::MulT()); + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + + const auto &tfliteAttr = tfliteOp->builtin_options.AsMulOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + // tfliteAttr->fused_activation_function + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Mul; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteMulParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h new file mode 100644 index 00000000000..90906df3ea3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_mul_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_MUL_PARSER_H +#define PREDICT_TFLITE_MUL_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteMulParser : public TfliteNodeParser { + public: + TfliteMulParser() : TfliteNodeParser("Mul") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_MUL_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc new file mode 100644 index 00000000000..3b0265cf875 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "securec/include/securec.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, + const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) { + auto count = 1; + std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; }); + auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType)); + auto buffer_idx = tflite_tensor->buffer; + if (!tfliteModelBuffer[buffer_idx]->data.empty()) { + tensor->data.resize(data_size); + auto ret = memcpy_s(tensor->data.data(), data_size, tfliteModelBuffer[buffer_idx]->data.data(), data_size); + if (ret) { + // MS_LOGE("memcpy tensor data failed, error code: %d", ret); + return ret; + } + } else { + // MS_LOGE("src tensor data is empty."); + return RET_ERROR; + } + return RET_OK; +} + +STATUS TfliteNodeParser::ParseWeight(const std::vector &weight_tenosrs, + const std::vector> &tfliteModelBuffer, + mindspore::lite::TensorCache *tensor_cache, schema::Format format) { + for (const auto &weight_tensor : weight_tenosrs) { + auto idx = tensor_cache->FindTensor(weight_tensor->name); + if (idx < 0) { + std::unique_ptr tensor(new schema::TensorT); + tensor->dataType = GetTfliteDataType(weight_tensor->type); + tensor->dims = weight_tensor->shape; + tensor->nodeType = schema::NodeType_ValueNode; + // memcpy tensor data + // buffer is 0 (which refers to an always existent empty buffer) + if (weight_tensor->buffer > 0) { + CopyTfliteTensorData(tfliteModelBuffer, weight_tensor, tensor.get()); + } + // MS_LOGD("add weight tensor name: %s", weight_tensor->name.c_str()); + tensor_cache->AddTensor(weight_tensor->name, tensor.release(), TF_CONST); + } + } + return RET_OK; +} + +STATUS TfliteNodeParser::ParseBias(const std::vector &bias_tensors, + const std::vector> &tfliteModelBuffer, + TensorCache *tensor_cache) { + for (const auto &bias_tensor : bias_tensors) { + auto idx = tensor_cache->FindTensor(bias_tensor->name); + if (idx < 0) { + std::unique_ptr tensor(new schema::TensorT); + tensor->dataType = GetTfliteDataType(bias_tensor->type); + tensor->dims = bias_tensor->shape; + tensor->nodeType = schema::NodeType_ValueNode; + // memcpy tensor data + // buffer is 0 (which refers to an always existent empty buffer) + if (bias_tensor->buffer > 0) { + CopyTfliteTensorData(tfliteModelBuffer, bias_tensor, tensor.get()); + } + // MS_LOGD("add weight tensor name: %s", bias_tensor->name.c_str()); + tensor_cache->AddTensor(bias_tensor->name, tensor.release(), TF_CONST); + } + } + return RET_OK; +} + +TypeId TfliteNodeParser::GetTfliteDataType(const tflite::TensorType &tflite_data_type) { + static std::unordered_map type_map = { + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + }; + auto iter = type_map.find(tflite_data_type); + if (iter == type_map.end()) { + return kTypeUnknown; + } + return iter->second; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h new file mode 100644 index 00000000000..45fc943b1b4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_NODE_PARSER_H +#define PREDICT_TFLITE_NODE_PARSER_H + +#include +#include +#include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" +#include "tools/converter/parser/tflite/tflite_util.h" +#include "tools/converter/parser/tflite/schema_generated.h" +#include "tools/common/tensor_util.h" +#include "ir/dtype/type_id.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class TfliteNodeParser { + public: + explicit TfliteNodeParser(const std::string &nodeName) : name(nodeName) {} + + virtual ~TfliteNodeParser() {} + + virtual STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) = 0; + + STATUS ParseWeight(const std::vector &weight_tenosr, + const std::vector> &tfliteModelBuffer, TensorCache *tensor_cache, + schema::Format format); + + STATUS ParseBias(const std::vector &weight_tenosr, + const std::vector> &tfliteModelBuffer, TensorCache *tensor_cache); + + STATUS CopyTfliteTensorData(const std::vector> &tfliteModelBuffer, + const tflite::TensorT *tflite_tensor, schema::TensorT *tensor); + + TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); + + template + STATUS GetTfliteData(const int32_t tensor_index, const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + std::vector &attr_data) { + int32_t count = 1; + std::for_each(tfliteTensors[tensor_index]->shape.begin(), tfliteTensors[tensor_index]->shape.end(), + [&](int32_t sha) { count *= sha; }); + auto &buf_data = tfliteModelBuffer[tfliteTensors[tensor_index]->buffer]; + auto data_ptr = buf_data->data.data(); + switch (tfliteTensors[tensor_index]->type) { + case tflite::TensorType_UINT8: { + for (int i = 0; i < count; i++) { + uint8_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(uint8_t); + } + break; + } + case tflite::TensorType_INT8: { + for (int i = 0; i < count; i++) { + int8_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int8_t); + } + break; + } + case tflite::TensorType_INT16: { + for (int i = 0; i < count; i++) { + int16_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int16_t); + } + break; + } + case tflite::TensorType_INT32: { + for (int i = 0; i < count; i++) { + int32_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int32_t); + } + break; + } + case tflite::TensorType_INT64: { + for (int i = 0; i < count; i++) { + int64_t data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(int64_t); + } + break; + } + case tflite::TensorType_FLOAT32: { + for (int i = 0; i < count; i++) { + float data = *(static_cast(static_cast(data_ptr))); + attr_data.emplace_back(static_cast(data)); + data_ptr += sizeof(float); + } + break; + } + } + return RET_OK; + } + + protected: + bool isQuantizedModel(); + + protected: + const std::string &name; + bool quantizedModel; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc new file mode 100644 index 00000000000..e6b7b4dc719 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +TfliteNodeParserRegistry::TfliteNodeParserRegistry() {} + +TfliteNodeParserRegistry::~TfliteNodeParserRegistry() {} + +TfliteNodeParserRegistry *TfliteNodeParserRegistry::GetInstance() { + static TfliteNodeParserRegistry instance; + return &instance; +} + +TfliteNodeParser *TfliteNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h new file mode 100644 index 00000000000..8649c3cf6aa --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H +#define MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H + +#include +#include +#include "tools/common/node_util.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" + +namespace mindspore { +namespace lite { +class TfliteNodeParserRegistry { + public: + TfliteNodeParserRegistry(); + + virtual ~TfliteNodeParserRegistry(); + + static TfliteNodeParserRegistry *GetInstance(); + + TfliteNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class TfliteNodeRegister { + public: + TfliteNodeRegister(const std::string &name, TfliteNodeParser *parser) { + TfliteNodeParserRegistry::GetInstance()->parsers[name] = parser; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TOOLS_LITE_CONVERTER_PARSER_TFLITE_NODE_PARSER_REGISTRY_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc new file mode 100644 index 00000000000..bfdfb5ea31f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteActivationParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteActivationParser"); + std::unique_ptr attr(new schema::ActivationT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteActivationParser("Relu6", new TfliteActivationParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h new file mode 100644 index 00000000000..4ef26e8f4a9 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_relu6_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RELU6_PARSER_H +#define PREDICT_TFLITE_RELU6_PARSER_H + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" +#include +#include + +namespace mindspore { +namespace lite { +class TfliteActivationParser : public TfliteNodeParser { + public: + TfliteActivationParser() : TfliteNodeParser("Relu6") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RELU6_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc new file mode 100644 index 00000000000..31ed316c78c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteReshapeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteReshapeParser"); + std::unique_ptr attr(new schema::ReshapeT()); + + const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + + attr->format = schema::Format_NHWC; + attr->shape.resize(tfliteAttr->new_shape.size()); + for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) { + attr->shape[i] = tfliteAttr->new_shape[i]; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Reshape; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h new file mode 100644 index 00000000000..a4237d3f059 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RESHAPE_PARSER_H +#define PREDICT_TFLITE_RESHAPE_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteReshapeParser : public TfliteNodeParser { + public: + TfliteReshapeParser() : TfliteNodeParser("Reshape") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc new file mode 100644 index 00000000000..a854afae42f --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteResizeBilinearParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteResizeBilinearParser"); + std::unique_ptr attr(new schema::ResizeT()); + const auto &tfliteAttr = tfliteOp->builtin_options.AsResizeBilinearOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + + attr->method = schema::ResizeMethod_BILINEAR; + attr->alignCorners = tfliteAttr->align_corners; + auto tfliteResizeTensorIndex = tfliteOp->inputs[1]; + auto resizeTensorBufferIndex = tfliteTensors.at(tfliteResizeTensorIndex)->buffer; + auto buffData = reinterpret_cast(tfliteModelBuffer.at(resizeTensorBufferIndex)->data.data()); + auto height = buffData[0]; + auto width = buffData[1]; + attr->newWidth = width; + attr->newHeight = height; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Resize; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeBilinearParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h new file mode 100644 index 00000000000..79c744941be --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_bilinear_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RESIZE_PARSER_H +#define PREDICT_TFLITE_RESIZE_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteResizeBilinearParser : public TfliteNodeParser { + public: + TfliteResizeBilinearParser() : TfliteNodeParser("ResizeBilinear") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_ADD_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc new file mode 100644 index 00000000000..7b31e064624 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteRsqrtParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteRsqrtParser"); + std::unique_ptr attr(new schema::RsqrtT()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Rsqrt; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteRsqrtParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h new file mode 100644 index 00000000000..9721e3b3fef --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rsqrt_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_RSQRT_PARSER_H +#define PREDICT_TFLITE_RSQRT_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteRsqrtParser : public TfliteNodeParser { + public: + TfliteRsqrtParser() : TfliteNodeParser("Rsqrt") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_RSQRT_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc new file mode 100644 index 00000000000..dc8661330f4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteSliceParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteSliceParser"); + std::unique_ptr attr(new schema::SliceT()); + + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->begin)) { + return RET_ERROR; + } + if (GetTfliteData(tfliteOp->inputs[2], tfliteTensors, tfliteModelBuffer, attr->size)) { + return RET_ERROR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Slice; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser()); +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h new file mode 100644 index 00000000000..965b128b679 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SLICE_PARSER_H +#define PREDICT_TFLITE_SLICE_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSliceParser : public TfliteNodeParser { + public: + TfliteSliceParser() : TfliteNodeParser("Slice") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SLICE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc new file mode 100644 index 00000000000..b1d2c54224a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSoftmaxParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteSoftmaxParser"); + std::unique_ptr attr(new schema::SoftMaxT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSoftmaxOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + // attr->axis + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SoftMax; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSoftmaxParser("Softmax", new TfliteSoftmaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h new file mode 100644 index 00000000000..728ed6bacf4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_CONV_PARSER_H +#define PREDICT_TFLITE_CONV_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSoftmaxParser : public TfliteNodeParser { + public: + TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_CONV_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.cc new file mode 100644 index 00000000000..7594bebff67 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSquaredDifferenceParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteSquaredDifferenceParser"); + std::unique_ptr attr(new schema::SquaredDifferenceT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsSquaredDifferenceOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_SquaredDifference; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteSquaredDifferenceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.h new file mode 100644 index 00000000000..6bebb64560e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squareddifference_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SQUAREDDIFFERENCE_PARSER_H +#define PREDICT_TFLITE_SQUAREDDIFFERENCE_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSquaredDifferenceParser : public TfliteNodeParser { + public: + TfliteSquaredDifferenceParser() : TfliteNodeParser("SquaredDifference") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SQUAREDDIFFERENCE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc new file mode 100644 index 00000000000..e67ad9a3697 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteStackParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGI("paser TfliteStackParser"); + std::unique_ptr attr(new schema::StackT()); + const auto &tflite_attr = tfliteOp->builtin_options.AsPackOptions(); + if (tflite_attr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + } + + attr->axis = tflite_attr->axis; + attr->n = tflite_attr->values_count; + attr->isScale.assign(tfliteTensors[tfliteOp->inputs[0]]->shape.begin(), + tfliteTensors[tfliteOp->inputs[0]]->shape.end()); + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Stack; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteStackParser("Stack", new TfliteStackParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h new file mode 100644 index 00000000000..b659d671e49 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_STACK_PARSER_H +#define PREDICT_TFLITE_STACK_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteStackParser : public TfliteNodeParser { + public: + TfliteStackParser() : TfliteNodeParser("Stack") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_STACK_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc new file mode 100644 index 00000000000..66dc9c3c874 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteSubParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) { + // MS_LOGD("parse TfliteSubParser"); + std::unique_ptr attr(new schema::SubT()); + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + + const auto &tfliteAttr = tfliteOp->builtin_options.AsSubOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + // tfliteAttr->fused_activation_function + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Sub; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteSubParser()); +} // namespace lite +} // namespace mindspore + + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h new file mode 100644 index 00000000000..fd25ed3fead --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sub_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_SUB_PARSER_H +#define PREDICT_TFLITE_SUB_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteSubParser : public TfliteNodeParser { + public: + TfliteSubParser() : TfliteNodeParser("Sub") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_SUB_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc new file mode 100644 index 00000000000..70c0e15faa3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteTanhParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + MS_LOG(DEBUG) << "parse TfliteTanhParser"; + std::unique_ptr attr(new schema::ActivationT()); + attr->type = schema::ActivationType_TANH; + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h new file mode 100644 index 00000000000..4449589ff4e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tanh_parser.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_TANH_PARSER_H +#define PREDICT_TFLITE_TANH_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTanhParser : public TfliteNodeParser { + public: + TfliteTanhParser() : TfliteNodeParser("Tanh") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, + bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_TANH_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc new file mode 100644 index 00000000000..4ebcbe093a6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h" + +namespace mindspore { +namespace lite { +STATUS TfliteTransposeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + // MS_LOGD("parse TfliteTransposeParser"); + std::unique_ptr attr(new schema::TransposeT()); + if (GetTfliteData(tfliteOp->inputs[1], tfliteTensors, tfliteModelBuffer, attr->perm)) { + return RET_ERROR; + } + + auto weight_index = tfliteOp->inputs[1]; + const auto &weight_tensor = tfliteTensors[weight_index]; + std::vector weight_tensors{weight_tensor.get()}; + + if (RET_OK != ParseWeight(weight_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) { + // MS_LOGE("parse weight failed"); + return RET_ERROR; + } + + const auto &tfliteAttr = tfliteOp->builtin_options.AsTransposeOptions(); + if (tfliteAttr == nullptr) { + // MS_LOGE("get op: %s attr failed", op->name.c_str()); + return RET_NULL_PTR; + } + + if (op != nullptr) { + op->primitive = std::make_unique(); + op->primitive->value.type = schema::PrimitiveType_Transpose; + op->primitive->value.value = attr.release(); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteTransposeParser("Transpose", new TfliteTransposeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h new file mode 100644 index 00000000000..89c57c2bc88 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PREDICT_TFLITE_TRANSPOSE_PARSER_H +#define PREDICT_TFLITE_TRANSPOSE_PARSER_H + +#include +#include +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h" +#include "mindspore/lite/tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteTransposeParser : public TfliteNodeParser { + public: + TfliteTransposeParser() : TfliteNodeParser("Transpose") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // PREDICT_TFLITE_TRANSPOSE_PARSER_H + diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc new file mode 100644 index 00000000000..9b9532c9ff8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/parser/tflite/tflite_util.h" +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace lite { +std::map tfMsActivationFunctionMap{ + {tflite::ActivationFunctionType_NONE, schema::ActivationType_NO_ACTIVATION}, + {tflite::ActivationFunctionType_RELU, schema::ActivationType_RELU}, + {tflite::ActivationFunctionType_RELU6, schema::ActivationType_RELU6}, +}; + +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { + return tfMsActivationFunctionMap.at(tfliteAFType); +} + +std::map tfMsOpTypeMap{ + {tflite::BuiltinOperator_CONV_2D, "Conv2D"}, + {tflite::BuiltinOperator_DEPTHWISE_CONV_2D, "DepthwiseConv2D"}, + {tflite::BuiltinOperator_AVERAGE_POOL_2D, "MeanPooling"}, + {tflite::BuiltinOperator_MAX_POOL_2D, "MaxPooling"}, + {tflite::BuiltinOperator_ADD, "Add"}, + {tflite::BuiltinOperator_CONCATENATION, "Concat"}, + {tflite::BuiltinOperator_RESIZE_BILINEAR, "ResizeBilinear"}, + {tflite::BuiltinOperator_RESHAPE, "Reshape"}, + {tflite::BuiltinOperator_LOGISTIC, "Logistic"}, + {tflite::BuiltinOperator_MUL, "Mul"}, + {tflite::BuiltinOperator_SOFTMAX, "Softmax"}, + {tflite::BuiltinOperator_FULLY_CONNECTED, "FullyConnected"}, + {tflite::BuiltinOperator_SLICE, "Slice"}, + {tflite::BuiltinOperator_SUB, "Sub"}, + {tflite::BuiltinOperator_TRANSPOSE, "Transpose"}, + {tflite::BuiltinOperator_PACK, "Stack"}, + {tflite::BuiltinOperator_MEAN, "Mean"}, + {tflite::BuiltinOperator_RELU6, "Relu6"}, + {tflite::BuiltinOperator_TANH, "Tanh"}, + {tflite::BuiltinOperator_RSQRT, "Rsqrt"}, + {tflite::BuiltinOperator_ARG_MAX, "Argmax"}, + {tflite::BuiltinOperator_SQUARED_DIFFERENCE, "SquaredDifference"}, + {tflite::BuiltinOperator_FAKE_QUANT, "FakeQuant"}, +}; + +std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType) { + auto iter = tfMsOpTypeMap.find(tfliteOpType); + if (iter == tfMsOpTypeMap.end()) { + return "unsupported_op_type"; + } + return iter->second; +} + +std::map type_map = { + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, +}; + +TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type) { + auto iter = type_map.find(tflite_data_type); + if (iter == type_map.end()) { + return kTypeUnknown; + } + return iter->second; +} + +schema::PadMode GetPadMode(tflite::Padding tflite_padmode) { + if (tflite_padmode == tflite::Padding_SAME) { + return schema::PadMode_SAME; + } else if (tflite_padmode == tflite::Padding_VALID) { + return schema::PadMode_VALID; + } else { + return schema::PadMode_NOTSET; + } +} + +size_t GetDataTypeSize(const TypeId &data_type) { + switch (data_type) { + case TypeId::kNumberTypeFloat32: + return sizeof(float); + case TypeId::kNumberTypeFloat16: + return sizeof(float) >> 1; + case TypeId::kNumberTypeInt8: + return sizeof(int8_t); + case TypeId::kNumberTypeInt32: + return sizeof(int); + case TypeId::kNumberTypeUInt8: + return sizeof(uint8_t); + case TypeId::kNumberTypeUInt32: + return sizeof(uint32_t); + default: + MS_LOG(ERROR) << "unsupport datatype"; + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.h b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h new file mode 100644 index 00000000000..3598c2f3a24 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_TFLITE_UTIL_H +#define MS_TFLITE_UTIL_H + + +#include +#include "utils/log_adapter.h" +#include "schema/inner/model_generated.h" +#include "tools/converter/parser/tflite/schema_generated.h" +#include "schema/inner/ops_generated.h" +#include "ir/dtype/type_id.h" + +// using namespace std; + +namespace mindspore { +namespace lite { +schema::PadMode GetPadMode(tflite::Padding tflite_padmode); + +size_t GetDataTypeSize(const TypeId &data_type); + +schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType); + +std::string GetMSOpType(tflite::BuiltinOperator tfliteOpType); + +TypeId GetTfliteDataType(const tflite::TensorType &tflite_data_type); +} // namespace lite +} // namespace mindspore + +#endif // MS_TFLITE_UTIL_H + diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt new file mode 100644 index 00000000000..d7fba0631a2 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -0,0 +1,19 @@ +set(3RD_DIR ../../../third_party) +include_directories(${3RD_DIR}/protobuf/build/include) +include_directories(${3RD_DIR}/flatbuffers/include) +include_directories(${3RD_DIR}/opencv/build/include/opencv4) + +add_library(quantizer_mid OBJECT + #${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc + #${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc + ${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc + #${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc + ) + +if(ENABLE_ASAN) + target_link_libraries(quantizer_mid libasan libSecodefuzz) +endif() diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc b/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc new file mode 100644 index 00000000000..3893b21c766 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/general_bitpacking.h" + +namespace mindspore { +namespace lite { +BitPack::BitPack(const uint8_t& bitnum) {this->bitnum = bitnum;} +void BitPack::UnPackFromUint8ToOrigin(uint8_t& n, std::queue& unpackBitData) { + int bitCount = 0; + while (bitCount < 8) { + bool a = n % 2; + n = n >> 1; + bitCount++; + unpackBitData.push(a); + } +} +void BitPack::UnPack(uint8_t bitnum, uint8_t& packedData, + std::vector &originData, std::queue& unpackBitData) { + UnPackFromUint8ToOrigin(packedData, unpackBitData); + // std::queue unpackBitTmpData; + + while (unpackBitData.size() > bitnum) { + uint32_t result = 0; + for (int k = 0; k < bitnum; k++) { + bool bitTmp = unpackBitData.front(); + result = (result << 1) + static_cast(bitTmp); + unpackBitData.pop(); + } + originData.push_back(result); + } +} +void BitPack::PackFromOriginToUint8(std::stack& ans, std::vector& packedDataVec) { + uint32_t result = 0; + for (size_t i = 0; i < 8; i++) { + bool bit_tmp = ans.top(); + result = (result << 1) + static_cast(bit_tmp); + ans.pop(); + } + packedDataVec.push_back(result); +} +void BitPack::DoBinary(uint8_t& n, std::stack& ans, std::vector& packedDataVec) { + int bitCount = 0; + while (bitCount < bitnum) { + bool a = n / (1 << (unsigned int)(bitnum - bitCount - 1)); + n = n - a * (1 << (unsigned int)(bitnum - bitCount - 1)); + bitCount++; + ans.push(a); + if (ans.size() == 8) { + PackFromOriginToUint8(ans, packedDataVec); + } + } +} + +void BitPack::BitPacking(const std::vector& originDataVec, std::vector& packedDataVec) { + std::stack bitDataVec; + for (size_t i = 0; i < originDataVec.size(); i++) { + uint8_t tmp = originDataVec[i]; + DoBinary(tmp, bitDataVec, packedDataVec); + } + + size_t remainBitData = bitDataVec.size(); + if ( 8 > remainBitData && remainBitData > 0 ) { + for ( int i = 0; i < 8 - remainBitData; i++ ) { + bitDataVec.push(0); + } + PackFromOriginToUint8(bitDataVec, packedDataVec); + } +} + +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.h b/mindspore/lite/tools/converter/quantizer/general_bitpacking.h new file mode 100644 index 00000000000..284c6e028d7 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/general_bitpacking.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GENERAL_BITPACKING_H +#define MINDSPORE_GENERAL_BITPACKING_H +#include +#include +#include +#include +#include + +namespace mindspore { +namespace lite { +class BitPack { + public: + explicit BitPack(const uint8_t &bitbum = 8); + ~BitPack() = default; + void BitPacking(const std::vector &originDataVec, std::vector &packedDataVec); + void UnPack(uint8_t bitnum, uint8_t &packedData, std::vector &originData, std::queue &unpackBitData); + + private: + void UnPackFromUint8ToOrigin(uint8_t &n, std::queue &unpackBitData); + void PackFromOriginToUint8(std::stack &ans, std::vector &packedDataVec); + void DoBinary(uint8_t &n, std::stack &ans, std::vector &packed_data_vec); + uint8_t bitnum; +}; +} // namespace lite +} // namespace mindspore + +#endif diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc new file mode 100644 index 00000000000..b96c1967f33 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/post_training.cc @@ -0,0 +1,926 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "src/ir/tensor.h" +#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/converter/quantizer/post_training.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "src/common/common.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/common/tensor_util.h" +#include "src/common/file_utils.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { + +struct DivergInfo { + std::vector histogram; + CNodePtr cnode; + int bin_num; + float interval = 0; + float max; + float min; + float best_T = 0.0f; + size_t bit_num; + int quant_max = 255; + int quant_min = 0; + DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { + this->cnode = cnode; + this->bin_num = bins; + this->bit_num = bits; + histogram.resize(bin_num); + max = FLT_MIN; + min = FLT_MAX; + this->quant_max = quant_max; + this->quant_min = quant_min; + std::fill(histogram.begin(), histogram.end(), 1.0e-7); + } + + STATUS RecordMaxValue(const std::vector &datas) { + for (float data : datas) { + max = std::max(data, max); + min = std::min(data, min); + } + return RET_OK; + } + + void UpdateInterval() { + auto max_value = std::max(fabs(this->max), fabs(this->min)); + this->interval = max_value / static_cast(bin_num); + } + + STATUS UpdateHistogram(const std::vector &data, const std::vector &shape) { + for (auto value : data) { + int bin_index = std::min(static_cast(std::fabs(value) / this->interval), bin_num - 1); + this->histogram[bin_index]++; + } + return RET_OK; + } + + void DumpHistogram() { + MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram"; + for (float item : this->histogram) { + std::cout << item << " "; + } + std::cout << std::endl; + } + + STATUS ComputeThreshold() { + constexpr int quant_bint_nums = 128; + int threshold = quant_bint_nums; + float min_kl = FLT_MAX; + float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f); + + for (int i = quant_bint_nums; i < this->bin_num; ++i) { + std::vector quantized_histogram(quant_bint_nums, 0); + std::vector reference_histogram(this->histogram.begin(), this->histogram.begin() + i); + std::vector expanded_histogram(i, 0); + reference_histogram[i - 1] += after_threshold_sum; + after_threshold_sum -= this->histogram[i]; + + const float bin_interval = static_cast(i) / static_cast(quant_bint_nums); + + // merge i bins to target bins + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + const int left_upper = static_cast(std::ceil(start)); + if (left_upper > start) { + const double left_scale = left_upper - start; + quantized_histogram[j] += left_scale * this->histogram[left_upper - 1]; + } + const int right_lower = static_cast(std::floor(end)); + if (right_lower < end) { + const double right_scale = end - right_lower; + quantized_histogram[j] += right_scale * this->histogram[right_lower]; + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, + [&quantized_histogram, j](float item) { quantized_histogram[j] += item; }); + } + // expand target bins to i bins in order to calculate KL with reference_histogram + for (int j = 0; j < quant_bint_nums; ++j) { + const float start = j * bin_interval; + const float end = start + bin_interval; + float count = 0; + const int left_upper = static_cast(std::ceil(start)); + float left_scale = 0.0f; + if (left_upper > start) { + left_scale = left_upper - start; + if (this->histogram[left_upper - 1] != 0) { + count += left_scale; + } + } + const int right_lower = static_cast(std::floor(end)); + double right_scale = 0.0f; + if (right_lower < end) { + right_scale = end - right_lower; + if (this->histogram[right_lower] != 0) { + count += right_scale; + } + } + std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, + [&count](float item) { + if (item != 0) { + count += 1; + } + }); + if (count == 0) { + continue; + } + const float average_num = quantized_histogram[j] / count; + if (left_upper > start && this->histogram[left_upper - 1] != 0) { + expanded_histogram[left_upper - 1] += average_num * left_scale; + } + if (right_lower < end && this->histogram[right_lower] != 0) { + expanded_histogram[right_lower] += average_num * right_scale; + } + for (int k = left_upper; k < right_lower; ++k) { + if (this->histogram[k] != 0) { + expanded_histogram[k] += average_num; + } + } + } + auto KLDivergence = [](std::vector p, std::vector q) { + auto sum = 0.0f; + std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; }); + std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; }); + sum = 0.0f; + std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; }); + std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; }); + + float result = 0.0f; + const int size = p.size(); + for (int i = 0; i < size; ++i) { + if (p[i] != 0) { + if (q[i] == 0) { + result += 1.0f; + } else { + result += (p[i] * std::log((p[i]) / (q[i]))); + } + } + } + return result; + }; + const float kl = KLDivergence(reference_histogram, expanded_histogram); + if (kl < min_kl) { + min_kl = kl; + threshold = i; + } + } + MS_LOG(DEBUG) << "Best threshold bin index: " << threshold; + this->best_T = (static_cast(threshold) + 0.5f) * this->interval; + return RET_OK; + } + + std::pair GetScale() { + float max_value = this->best_T; + float min_value = -max_value; + MS_ASSERT(quant_max - quant_min != 0); + double scale = (max_value - min_value) / (quant_max - quant_min); + MS_ASSERT(scale != 0); + return std::make_pair(this->cnode, scale); + } + + std::pair GetZeropoint() { + float max_value = this->best_T; + float min_value = -max_value; + MS_ASSERT(quant_max - quant_min != 0); + float scale = (max_value - min_value) / (quant_max - quant_min); + + auto quant_min_float = static_cast(quant_min); + auto quant_max_float = static_cast(quant_max); + MS_ASSERT(scale != 0); + const float zero_point_from_min = quant_min_float - min_value / scale; + // const float zero_point_from_max = quant_max_float - max_value / scale; + int zero_point; + if (zero_point_from_min < quant_min_float) { + zero_point = quant_min; + } else if (zero_point_from_min > quant_max_float) { + zero_point = quant_max; + } else { + zero_point = static_cast(std::round(zero_point_from_min)); + } + return std::make_pair(this->cnode, zero_point); + } +}; +std::unordered_map Calibrator::GetResult( + std::unordered_map> *diverg_info) { + std::unordered_map result; + for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) { + DivergInfo *info = iter->second.get(); + auto item = info->GetScale(); + result.insert(item); + } + return result; +} +std::unordered_map Calibrator::GetZeropoint( + std::unordered_map> *mDivergInfo) { + std::unordered_map result; + for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + DivergInfo *info = iter->second.get(); + auto zeropoint = info->GetZeropoint(); + result.insert(zeropoint); + } + return result; +} + +std::map Calibrator::GetMinMax( + std::unordered_map> *mDivergInfo) { + std::map result; + for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + DivergInfo *info = iter->second.get(); + mindspore::lite::quant::MaxMin input_maxmin{}; + input_maxmin.min = info->min; + input_maxmin.max = info->max; + result[info->cnode] = input_maxmin; + } + return result; +} + +void Calibrator::Dump() { + for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + info->DumpHistogram(); + } +} + +std::unordered_map> *Calibrator::GetInputDivergInfo() { + return &this->input_diverg_info_; +} + +std::unordered_map> *Calibrator::GetOutputDivergInfo() { + return &this->output_diverg_info_; +} + +STATUS Calibrator::RecordMaxValue(std::string opName, vector data, + std::unordered_map> *mDivergInfo) { + auto got = (*mDivergInfo).find(opName); + if (got != (*mDivergInfo).end()) { + ((*got).second)->RecordMaxValue(data); + } + return RET_OK; +} + +STATUS Calibrator::ComputeThreshold() { + for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + info->ComputeThreshold(); + } + for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) { + DivergInfo *info = iter->second.get(); + info->ComputeThreshold(); + } + return RET_OK; +} + +STATUS Calibrator::UpdateDivergInverval(std::unordered_map> *diverg_info) { + for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) { + DivergInfo *info = iter->second.get(); + info->UpdateInterval(); + } + return RET_OK; +} + +STATUS Calibrator::UpdateDataFrequency(std::string op_name, vector data, vector shape, + std::unordered_map> *diverg_info) { + auto got = (*diverg_info).find(op_name); + if (got != (*diverg_info).end()) { + ((*got).second)->UpdateHistogram(data, shape); + } + return RET_OK; +} + +STATUS Calibrator::AddQuantizedOp(CNodePtr node) { + if (node == nullptr) { + MS_LOG(ERROR) << "To be quantized node is null"; + return RET_ERROR; + } + string node_name = node->fullname_with_scope(); + std::unique_ptr input_diverg = + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + std::unique_ptr output_diverg = + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + + input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); + output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); + return RET_OK; +} + +void Calibrator::AddImage(const string file) { + auto exist = [](const string file) { + struct stat buf; + return stat(file.c_str(), &buf) == 0; + }; + if (exist(file)) { + MS_LOG(INFO) << "load image: " << file; + this->images_.push_back(file); + } else { + MS_LOG(WARNING) << "Invaild image file path: " << file; + } +} + +STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTensor *tensor) const { + string path = images_[index]; + MS_LOG(INFO) << "read image: " << path; + size_t size; + char *binBuf = ReadFile(path.c_str(), &size); + + // auto *rawinputDatas = reinterpret_cast(binBuf); + // auto mobilenet_input = const_cast(rawinputDatas); + auto data = tensor->MutableData(); + memcpy(data, binBuf, size); + + // tensor->SetData(mobilenet_input); + return RET_OK; +} + +STATUS Calibrator::CollectImages() { + // check image file path + DIR *root = opendir(config_param_.image_path.c_str()); + if (root == nullptr) { + MS_LOG(ERROR) << "invalid image path: " << config_param_.image_path; + return RET_PARAM_INVALID; + } + struct dirent *image_dir = readdir(root); + int count = 0; + while (image_dir != nullptr) { + if (image_dir->d_name[0] != '.') { + const std::string file_name = config_param_.image_path + "/" + image_dir->d_name; + if (config_param_.batch_count == 0) { + this->AddImage(file_name); + count++; + } else if (count < config_param_.batch_count) { + this->AddImage(file_name); + count++; + } else { + break; + } + } + image_dir = readdir(root); + } + closedir(root); + return RET_OK; +} + +STATUS Calibrator::ReadConfig() { + if (config_path_.empty() || config_path_.length() > PATH_MAX) { + MS_LOG(ERROR) << "invalid config path!"; + return RET_PARAM_INVALID; + } + // check whether config file path is valid + char *resolved_path = new (std::nothrow) char[PATH_MAX]{0}; + if (resolved_path == nullptr) { + MS_LOG(ERROR) << "New an object failed."; + return RET_ERROR; + } + if (nullptr != realpath(config_path_.c_str(), resolved_path)) { + config_path_ = string(resolved_path); + } + std::ifstream fs(config_path_.c_str(), std::ifstream::in); + if (!fs.is_open()) { + MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_; + delete[] resolved_path; + return RET_PARAM_INVALID; + } + std::string line; + while (std::getline(fs, line)) { + auto index = line.find('='); + if (index == std::string::npos) { + MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; + delete[] resolved_path; + return RET_PARAM_INVALID; + } + auto key = line.substr(0, index); + auto value = line.substr(index + 1); + if (key == "image_path") { + config_param_.image_path = value; + } else if (key == "batch_count") { + config_param_.batch_count = std::stoul(value); + } else if (key == "thread_num") { + config_param_.thread_num = std::stoul(value); + } else { + MS_LOG(WARNING) << "unsupported parameter"; + } + } + MS_LOG(INFO) << "image_path: " << config_param_.image_path << " " + << "batch_count: " << config_param_.batch_count << " " + << "thread_num: " << config_param_.thread_num; + + delete[] resolved_path; + fs.close(); + return RET_OK; +} + +Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) + : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} + +PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) + : Quantizer(graph) { + this->bit_num = bit_num; + calibrator_ = std::unique_ptr(new Calibrator(path, this->bit_num, quant_max, quant_min)); + if (calibrator_ == nullptr) { + MS_LOG(ERROR) << "creat calibrator failed!"; + return; + } + this->target_type_ = target_type; + if (target_type == kNumberTypeInt8) { + quant_max = (1 << (this->bit_num - 1)) - 1; // 127 + quant_min = -(1 << (this->bit_num - 1)); // -128 + } else if (target_type == kNumberTypeUInt8) { + quant_max = (1 << this->bit_num) - 1; // 255 + quant_min = 0; + } else { + MS_LOG(ERROR) << "unsupported quant value type: " << target_type; + } +} + +STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive) { + if (!lite_primitive->GetInputQuantParams().empty()) { + return RET_OK; + } + schema::QuantParamT quant_param; + quant_param.scale = scale; + quant_param.zeroPoint = zeropoint; + quant_param.max = max_min->max; + quant_param.min = max_min->min; + quant_param.numBits = bit_num; + quant_param.narrowRange = false; + lite_primitive->AddInputQuantParam(quant_param); + // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, + std::shared_ptr lite_primitive) { + if (!lite_primitive->GetOutputQuantParams().empty()) { + return RET_OK; + } + schema::QuantParamT quant_param; + quant_param.scale = scale; + quant_param.zeroPoint = zeropoint; + quant_param.max = max_min->max; + quant_param.min = max_min->min; + quant_param.numBits = bit_num; + quant_param.narrowRange = false; + lite_primitive->AddOutputQuantParam(quant_param); + // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT)); + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { + // const vector dims = filter->dims; + // perlayer + if (!node->isa()) { + MS_LOG(ERROR) << "not a parameter"; + return RET_PARAM_INVALID; + } + auto parameter = std::dynamic_pointer_cast(node); + ParamValueLitePtr paramValue = std::dynamic_pointer_cast(parameter->default_param()); + auto status = QuantFilter(paramValue, QuantType_PostTraining, bit_num); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed: " << status; + return status; + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr input, AnfNodePtr weight, AnfNodePtr bias) { + if (input == nullptr || weight == nullptr || bias == nullptr) { + MS_LOG(ERROR) << "null pointer!"; + return RET_NULL_PTR; + } + + ParameterPtr weightParameterPtr = std::dynamic_pointer_cast(weight); + auto default_param = weightParameterPtr->default_param(); + auto weight_param = std::dynamic_pointer_cast(default_param); + // std::vector> weight_quant_params = weight_param->get_quant_params(); + + ParameterPtr biasParameterPtr = std::dynamic_pointer_cast(bias); + auto bias_default_param = biasParameterPtr->default_param(); + auto bias_param = std::dynamic_pointer_cast(bias_default_param); + + vector input_scales; + vector filter_scales; + vector bias_scales; + auto quant_params = input->GetInputQuantParams(); + size_t sizeX = quant_params.size(); + for (size_t i = 0; i < sizeX; i++) { + input_scales.emplace_back(quant_params[i].scale); + } + size_t sizeY = weight_param->quant_param().size(); + if (sizeX != sizeY) { + if (sizeX > 1 && sizeY > 1) { + MS_LOG(ERROR) << "input and filter's scale count cannot match!"; + return RET_ERROR; + } + } + for (size_t i = 0; i < sizeY; i++) { + auto scale = weight_param->quant_param()[i]->scale; + filter_scales.push_back(scale); + } + size_t size = std::max(sizeX, sizeY); + for (size_t i = 0; i < size; i++) { + auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0]; + auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0]; + bias_scales.push_back(scaleX * scaleY); + } + MS_ASSERT(!bias_scales.empty()); + size_t shape_size = bias_param->tensor_shape_size(); + + // set bias quant param + bias_param->quant_param().clear(); + for (size_t i = 0; i < bias_scales.size(); i++) { + std::unique_ptr param(new (std::nothrow) AnfQuantParam()); + param->scale = bias_scales[i]; + param->zeroPoint = 0; + bias_param->quant_param().emplace_back(std::move(param)); + } + // quant bias data + int32_t *quant_datas = new (std::nothrow) int32_t[shape_size]; + if (quant_datas == nullptr) { + MS_LOG(ERROR) << "null pointer dereferencing."; + return RET_NULL_PTR; + } + float *raw_datas = reinterpret_cast(bias_param->tensor_addr()); + double bias_scale_tmp; + for (size_t i = 0; i < shape_size; i++) { + if (bias_scales.size() == 1) { + bias_scale_tmp = bias_scales[0]; + } else { + bias_scale_tmp = bias_scales[i]; + } + auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); + quant_datas[i] = quant_data; + } + auto ret = + memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + delete[] quant_datas; + return RET_ERROR; + } + delete[] quant_datas; + bias_param->set_tensor_type(kNumberTypeInt32); + return RET_OK; +} + +// STATUS PostTrainingQuantizer::reformatConvWeight(GraphDefT *graph) { +// for (auto &subGraph : graphDefT->subgraphs) { +// for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { +// OpDefT *node = (*iter).get(); +// bool isConv = false; +// kTransFilterType tansType; +// if ((*node).attr.type == OpT_Conv2D) { +// tansType = kKCHW2HWCK; +// isConv = true; +// } +// else if ((*node).attr.type == OpT_DepthwiseConv2D) { +// tansType = kCKHW2HWCK; +// isConv = true; +// } +// if (isConv) { +// auto status = TransFilterFormat(&(*subGraph.get()->allTensors.at(node->inputIndex[1])), +// tansType); +// if (status != RET_OK) { +// return status; +// } +// TensorDefT *weight = subGraph->allTensors.at(node->inputIndex[1]).get(); +// weight->format = Format_HWCK; +// PostBitPack(weight, bitNum); +// } +// } +// } +//} + +STATUS PostTrainingQuantizer::QuantNode() { + auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo()); + auto input_scale = this->calibrator_->GetResult(this->calibrator_->GetInputDivergInfo()); + auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo()); + + auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo()); + auto output_scale = this->calibrator_->GetResult(this->calibrator_->GetOutputDivergInfo()); + auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo()); + + auto cnodes = funcGraph->GetOrderedCnodes(); + for (auto &cnode : cnodes) { + auto cnode_name = cnode->fullname_with_scope(); + if (this->calibrator_->GetInputDivergInfo()->find(cnode_name) == this->calibrator_->GetInputDivergInfo()->end()) { + MS_LOG(INFO) << cnode_name << " can not do quant"; + continue; + } + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + continue; + } + + if (input_scale.find(cnode) == input_scale.end()) { + primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); + continue; + } + auto input_vec = cnode->inputs(); + auto op_name = cnode->fullname_with_scope(); + MS_LOG(INFO) << "OpName: " << op_name; + if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { + MS_LOG(INFO) << "todo(x): "; + // int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; + // p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); + } else { + // do input quant + double scale = input_scale[cnode]; + int32_t convInputzeropoint = input_zero_point[cnode]; + DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value); + // do weight quant + auto weight = cnode->input(2); + DoWeightQuant(weight); + // do bias quant + if (cnode->inputs().size() == 4) { + auto bias = cnode->input(3); + DoBiasQuant(primitiveT_value, weight, bias); + } + } + // do output quant + double OutputScale = output_scale[cnode]; + int32_t OutputZeropoint = output_zeropoint[cnode]; + DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value); + primitiveT_value->SetQuantType(schema::QuantType_PostTraining); + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::UpdateDivergInverval() { + this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo()); + this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo()); + return RET_OK; +} + +/** + * Pre Process + * 1. generate config param + * 1.1 read config file + * 1.2 parse txt + * 2. collect image files + * 2.1 parse image files to input tensor + * 3. save quantied node + **/ +STATUS PostTrainingQuantizer::PreProcess() { + if (this->calibrator_ == nullptr) { + MS_LOG(ERROR) << "calibrator is null!"; + return RET_ERROR; + } + // 1. generate config param + STATUS status = calibrator_->ReadConfig(); + if (status != RET_OK) { + MS_LOG(ERROR) << "read proto text failed!"; + return status; + } + // 2. collect image files + status = calibrator_->CollectImages(); + if (status != RET_OK) { + MS_LOG(ERROR) << "collect images failed!"; + return status; + } + // 3. collect to be quantized operators + // from user input + QuantStrategy strategy(10); + auto cnodes = funcGraph->GetOrderedCnodes(); + for (auto cnode : cnodes) { + AnfNodePtr anf = std::dynamic_pointer_cast(cnode); + if (strategy.CanOpPostQuantized(anf)) { + MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized"; + calibrator_->AddQuantizedOp(cnode); + } + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName, + const std::vector &tensorVec) const { + if (tensorVec.size() < 1) { + MS_LOG(ERROR) << "node: " << nodeName << " input tensors is 0"; + return RET_ERROR; + } + tensor::Tensor *tensor = tensorVec[0]; + if (tensor->data_type() != kNumberTypeFloat) { + //&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT + MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize"; + } + return RET_OK; +} + +/** + * 1. create input tensor + * 2. insert callback to session + * 3. run session + **/ +STATUS PostTrainingQuantizer::DoInference() { + for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) { + // TODO(x) when model has inputs count > 1 + // get input tensor + vector inputs = session_->GetInputs(); + if (inputs.size() > 1) { + MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " >1"; + return RET_ERROR; + } + STATUS status = calibrator_->GenerateInputData(i, inputs.front()); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + /** + * struct CallBackParam { + std::string nodeType; + NODE_ID nodeName; + std::unordered_set depends; + int opExecResult; + }; + */ + mindspore::kernel::KernelCallBack beforeCallBack = [&](const std::vector &beforeInputs, + const std::vector &beforeOutputs, + const mindspore::kernel::CallBackParam &callParam) -> bool { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_aram, beforeInputs) != RET_OK) { + return false; + } + auto tensor = beforeInputs[0]; + const float *tData = static_cast(tensor->Data()); + size_t shapeSize = tensor->ElementsNum(); + vector data(tData, tData + shapeSize); + this->calibrator_->RecordMaxValue(callParam.name_callback_aram, data, this->calibrator_->GetInputDivergInfo()); + return true; + }; + // func + mindspore::kernel::KernelCallBack afterCallBack = [&](const std::vector &afterInputs, + const std::vector &afterOutputs, + const mindspore::kernel::CallBackParam &callParam) -> bool { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_aram, afterOutputs) != RET_OK) { + return false; + } + auto tensor = afterOutputs[0]; + const float *tensor_data = static_cast(tensor->Data()); + size_t shape_size = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + shape_size); + this->calibrator_->RecordMaxValue(callParam.name_callback_aram, data, this->calibrator_->GetOutputDivergInfo()); + return true; + }; + status = session_->RunGraph(beforeCallBack, afterCallBack); + if (status != RET_OK) { + MS_LOG(ERROR) << "run model failed!"; + return RET_ERROR; + } + } + return RET_OK; +} + +STATUS PostTrainingQuantizer::CollectDataFrequency() { + for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) { + // TODO(x) when model has inputs count > 1 + // get input tensor + vector inputs = session_->GetInputs(); + if (inputs.size() > 1) { + MS_LOG(ERROR) << "model's input tensor size: " << inputs.size() << " > 1"; + return RET_ERROR; + } + STATUS status = calibrator_->GenerateInputData(i, inputs.front()); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + + mindspore::kernel::KernelCallBack beforeCallBack = [&](const std::vector &beforeInputs, + const std::vector &beforeOutputs, + const mindspore::kernel::CallBackParam &callParam) { + if (PostTrainingQuantizer::CheckTensorVec(callParam.name_callback_aram, beforeInputs) != RET_OK) { + return false; + } + auto tensor = beforeInputs[0]; + const float *tensor_data = static_cast(tensor->Data()); + size_t shape_size = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + shape_size); + this->calibrator_->UpdateDataFrequency(callParam.name_callback_aram, data, tensor->shape(), + this->calibrator_->GetInputDivergInfo()); + return true; + }; + + mindspore::kernel::KernelCallBack afterCallBack = [&](const std::vector &after_inputs, + const std::vector &after_outputs, + const mindspore::kernel::CallBackParam &call_param) { + if (PostTrainingQuantizer::CheckTensorVec(call_param.name_callback_aram, after_outputs) != RET_OK) { + return false; + } + auto tensor = after_outputs[0]; + const float *tenosr_data = static_cast(tensor->Data()); + size_t shape_size = tensor->ElementsNum(); + vector data(tenosr_data, tenosr_data + shape_size); + this->calibrator_->UpdateDataFrequency(call_param.name_callback_aram, data, tensor->shape(), + this->calibrator_->GetOutputDivergInfo()); + return true; + }; + status = session_->RunGraph(beforeCallBack, afterCallBack); + if (status != RET_OK) { + MS_LOG(ERROR) << "run model failed!"; + return RET_ERROR; + } + } + + return RET_OK; +} + +STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); } + +STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) { + MS_LOG(INFO) << "start to parse config file"; + STATUS status = PreProcess(); + if (status != RET_OK) { + MS_LOG(ERROR) << "do pre process failed!"; + return status; + } + MS_LOG(INFO) << "start create session"; + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph)); + builder.Finish(offset); + size_t size = builder.GetSize(); + auto *content = reinterpret_cast(builder.GetBufferPointer()); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + auto model = lite::Model::Import(content, size); + + Context ctx; + ctx.deviceCtx.type = DT_CPU; + ctx.threadNum = calibrator_->GetThreadNum(); + ctx.cpuBindMode = MID_CPU; + + session_ = dynamic_cast(session::LiteSession::CreateSession(&ctx)); + if (session_ == nullptr) { + MS_LOG(ERROR) << "create session failed!"; + return RET_ERROR; + } + + auto ret = session_->CompileGraph(model.get()); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "compile graph error"; + return RET_ERROR; + } + + MS_LOG(INFO) << "start to update divergence's max value"; + status = DoInference(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to update divergence's interval"; + status = UpdateDivergInverval(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to collect data's distribution"; + status = CollectDataFrequency(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "compute the best threshold"; + status = ComputeThreshold(); + if (status != RET_OK) { + return status; + } + MS_LOG(INFO) << "start to generate quant param and quantize tensor's data"; + status = QuantNode(); + if (status != RET_OK) { + return status; + } + return RET_OK; +} +} // namespace quant +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/post_training.h b/mindspore/lite/tools/converter/quantizer/post_training.h new file mode 100644 index 00000000000..3d83c4c923d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/post_training.h @@ -0,0 +1,159 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POSTRAINING_QUANTIZER_H +#define POSTRAINING_QUANTIZER_H + +#include +#include +#include +#include +#include +#include +#include "src/lite_session.h" +#include "tools/converter/quantizer/quantizer.h" +#include "src/ir/primitive_t_value.h" +#include "tools/converter/converter.h" + +namespace mindspore { +namespace lite { +namespace quant { +class Calibrator; + +struct MaxMin { + public: + float min; + float max; +}; + +enum ImageFormat { + RGB = 0, + GRAY = 1, + BGR = 2, +}; + +struct ConfigParam { + // ImageFormat imageFormat; + std::string image_path; + uint32_t batch_count; + uint32_t thread_num; +}; + +class PostTrainingQuantizer : public Quantizer { + public: + PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); + + STATUS DoQuantize(FuncGraphPtr funcGraph) override; + + size_t bit_num; + int quant_max{255}; + int quant_min{0}; + + private: + TypeId target_type_{kNumberTypeUInt8}; + + std::unique_ptr calibrator_; + + mindspore::lite::LiteSession *session_; + + STATUS PreProcess(); + + STATUS CheckTensorVec(const std::string &nodeName, const std::vector &tensorVec) const; + + STATUS DoInference(); + + STATUS UpdateDivergInverval(); + + STATUS CollectDataFrequency(); + + STATUS ComputeThreshold(); + + STATUS QuantNode(); + + // STATUS reformatConvWeight(GraphDefT *graph); + + STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + + STATUS DoWeightQuant(AnfNodePtr node); + + STATUS DoBiasQuant(std::shared_ptr input, AnfNodePtr weight, AnfNodePtr bias); +}; + +struct DivergInfo; + +class Calibrator { + public: + explicit Calibrator(std::string path, size_t quant_size, int quant_max, int quant_msin); + + ~Calibrator() = default; + + STATUS ReadConfig(); + + STATUS CollectImages(); + + STATUS GenerateInputData(int index, mindspore::tensor::MSTensor *tensor) const; + + size_t GetBatchNum() const { return images_.size(); } + + uint32_t GetThreadNum() const { return config_param_.thread_num; } + + STATUS AddQuantizedOp(CNodePtr node); + + STATUS RecordMaxValue(std::string opName, std::vector data, + std::unordered_map> *diverg_info); + + STATUS UpdateDivergInverval(std::unordered_map> *diverg_info); + + STATUS UpdateDataFrequency(std::string op_name, std::vector data, std::vector shape, + std::unordered_map> *diverg_info); + void Dump(); + + STATUS ComputeThreshold(); + + std::unordered_map GetResult( + std::unordered_map> *diverg_info); + + std::unordered_map GetZeropoint( + std::unordered_map> *diverg_info); + + std::map GetMinMax(std::unordered_map> *diverg_info); + + std::unordered_map> *GetInputDivergInfo(); + + std::unordered_map> *GetOutputDivergInfo(); + + private: + std::vector images_; + + std::string config_path_; + + ConfigParam config_param_; + + std::unordered_map> input_diverg_info_; + + std::unordered_map> output_diverg_info_; + + size_t bit_num_; + int quant_max_; + int quant_min_; + + void AddImage(std::string file); +}; +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif // POSTRAINING_QUANTIZER_H diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc new file mode 100644 index 00000000000..64dd0c0c183 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -0,0 +1,343 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include "src/ir/primitive_t_value.h" +#include "mindspore/lite/tools/converter/quantizer/quantize_util.h" +#include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h" +#include "src/common/utils.h" +#include "abstract/abstract_value.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { +const std::array QuantStrategy::mConvTypes = { + {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}}; +const std::array QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}}; + +QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold) + : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {} + +bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { + size_t i = 0; + for (i = 0; i < mConvTypes.size(); i++) { + if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { + break; + } + } + + if ((i == mConvTypes.size()) || (node->size() < 3)) { + return false; + } + + auto inputNode = node->input(2); + if (!inputNode->isa()) { + return false; + } + auto paramNode = inputNode->cast(); + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + return false; + } + + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { + MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; + return false; + } + if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { + MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0]; + return false; + } + + return true; +} + +bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { + if (!node->isa()) { + return false; + } + auto cnode = std::dynamic_pointer_cast(node); + + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + return false; + } + + auto type = primitiveT_value->GetPrimitiveT()->value.type; + MS_LOG(INFO) << "Primitive type: " << type; + static const std::vector uint8OpList = { + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation}; + return IsContain(uint8OpList, type); +} + +bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { + size_t i = 0; + for (i = 0; i < mMulTypes.size(); i++) { + if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { + break; + } + } + if (i == mMulTypes.size()) { + return false; + } + + if (node->size() < 3) { + MS_LOG(INFO) << "input size less!"; + return false; + } + + auto inputNode1 = node->input(1); + auto inputNode2 = node->input(2); + if (inputNode1 == nullptr || inputNode2 == nullptr) { + MS_LOG(INFO) << "mul input is nullptr!"; + return false; + } + + ParameterPtr paramNode = nullptr; + if (inputNode1->isa()) { + paramNode = inputNode1->cast(); + } else if (inputNode2->isa()) { + paramNode = inputNode2->cast(); + } + + if (paramNode == nullptr) { + MS_LOG(INFO) << "invalid paramNode!"; + return false; + } + + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + MS_LOG(INFO) << "abstract is nullptr"; + return false; + } + + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { + MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; + return false; + } + + return true; +} + +void CalFakeNode(const AnfNodePtr &inTensor) { + // MS_ASSERT(inTensor != nullptr); + // MS_ASSERT(inTensor->dataType == DataType_DT_FLOAT); + // auto quantParam = GetTensorQuantParams(inTensor); + // if (quantParam == nullptr || !quantParam->inited) { + // MS_LOGW("tensor quantParam has not been inited"); + // return; + // } + + // float quantMin = quantParam->narrowRange ? 1 : 0; + // float quantMax = (1 << (unsigned int)(quantParam->numBits)) - 1; + // const float scale = quantParam->scale; + // const float nudgedMin = (quantMin - quantParam->zeroPoint) * scale; + // const float nudgedMax = (quantMax - quantParam->zeroPoint) * scale; + // // cal output + // float invNudgeScale = 1.0f / scale; + // void *inData = inTensor->data.data(); + // if(inData == nullptr) { + // MS_LOGE("null pointer dereferencing."); + // return; + // } + // auto *data = static_cast(inData); + // for (size_t i = 0; i < GetShapeSize(*inTensor); i++) { + // float clamped = std::min(nudgedMax, std::max(nudgedMin, data[i])); + // float clampedShifted = clamped - nudgedMin; + // data[i] = std::round(clampedShifted * invNudgeScale) * scale + nudgedMin; + // } +} + +STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, + double mMax, bool narrowRange, int numBits) { + MS_ASSERT(quantParam != nullptr); + if (mMin > 0.0f) { + MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; + mMin = 0.0f; + } + if (mMax < 0.0f) { + MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; + mMax = 0.0f; + } + if (mMin > mMax) { + MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; + return RET_PARAM_INVALID; + } + if (mMin == mMax) { + if (mMin != 0.0f) { + MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; + return RET_ERROR; + } + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = 0.0f; + quantParam->zeroPoint = 0; + quantParam->narrowRange = narrowRange; + quantParam->numBits = numBits; + return RET_OK; + } + + int quantMin = narrowRange ? 1 : 0; + int quantMax = (1 << (unsigned int)numBits) - 1; + auto quantMinFloat = static_cast(quantMin); + auto quantMaxFloat = static_cast(quantMax); + double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); + const double zeroPointFromMin = quantMinFloat - mMin / scale; + const double zeroPointFromMax = quantMaxFloat - mMax / scale; + const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale); + const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale); + const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax; + int zeroPoint; + if (zpDouble < quantMinFloat) { + zeroPoint = quantMin; + } else if (zpDouble > quantMaxFloat) { + zeroPoint = quantMax; + } else { + zeroPoint = static_cast(std::round(zpDouble)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + MS_ASSERT(zeroPoint >= quantMin); + MS_ASSERT(zeroPoint <= quantMax); + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = scale; + quantParam->zeroPoint = zeroPoint; + quantParam->narrowRange = narrowRange; + quantParam->numBits = numBits; + + return RET_OK; +} + +STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bitNum) { + auto dims = weightPtr->tensor_shape(); + if (dims.size() < 1) { + MS_LOG(ERROR) << "weight dims size error"; + return RET_ERROR; + } + uint32_t channels = dims[0]; + if (channels == 0) { + MS_LOG(ERROR) << "channels error 0"; + return RET_ERROR; + } + + size_t shapeSize = weightPtr->tensor_shape_size(); + size_t oneFilterSize = shapeSize / channels; + auto *rawDatas = reinterpret_cast(weightPtr->tensor_addr()); + if (rawDatas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } + + weightPtr->quant_param().clear(); + vector qDatas(shapeSize); + for (uint32_t i = 0; i < channels; i++) { + float min = 0; + float max = 0; + // find min and max + for (uint32_t j = 0; j < oneFilterSize; j++) { + min = std::min(min, rawDatas[j + i * oneFilterSize]); + max = std::max(max, rawDatas[j + i * oneFilterSize]); + } + + std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); + STATUS status = CalQuantizationParams(quantParam, min, max, false, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + // update data and datatype + for (uint32_t j = 0; j < oneFilterSize; j++) { + float rawData = rawDatas[j + i * oneFilterSize]; + auto qData = QuantizeData(rawData, quantParam.get()); + qDatas[j + i * oneFilterSize] = qData; + } + + weightPtr->set_quant_param(quantParam); + } + auto ret = memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), + qDatas.data(), shapeSize * sizeof(uint8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + if (quantType == QuantType_WeightQuant) { + PostBitPack(const_cast(rawDatas), shapeSize, bitNum); + } + + weightPtr->set_tensor_type(kNumberTypeUInt8); + weightPtr->set_tensor_size(shapeSize * sizeof(uint8_t)); + + return RET_OK; +} + +STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { + auto *rawDatas = reinterpret_cast(weight); + vector qDatas(rawDatas, rawDatas + shapeSize); + vector qDatas_packed; + if (bitNum < 8 && bitNum > 1) { + BitPack weight_bitpack(bitNum); + weight_bitpack.BitPacking(qDatas, qDatas_packed); + if (0 != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; + return RET_ERROR; + } + } else if (bitNum == 8) { + if (0 != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; + return RET_ERROR; + } + + return RET_OK; +} +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h new file mode 100644 index 00000000000..12cd2b4691d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef QUANTIZER_UTIL_H +#define QUANTIZER_UTIL_H + +#include +#include +#include +#include +#include "include/errorcode.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "ir/primitive.h" +#include "abstract/dshape.h" +#include "mindspore/lite/tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +namespace quant { + +static constexpr size_t UINT8_QUANTIZATION = 8; + +/** + * 1. when op's weight size > mWeightSize just skip + * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization + * 3. when conv/deconv/convdepthwise/deconvdepthwise ops' weight channel size > covWeightQuantChannelThreshold just skip + * */ +class QuantStrategy { + public: + explicit QuantStrategy(size_t weightSize, size_t covWeightQuantChannelThreshold = 16); + + ~QuantStrategy() = default; + + bool CanConvOpQuantized(const CNodePtr &node) const; + bool CanMulOpQuantized(const CNodePtr &node) const; + bool CanOpPostQuantized(AnfNodePtr &node) const; + + private: + size_t mWeightSize; + size_t mConvWeightQuantChannelThreshold; + + static const std::array mConvTypes; + static const std::array mMulTypes; +}; + +STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, double mMax, + bool narrowRange = false, int numBits = UINT8_QUANTIZATION); + +template +T QuantizeData(const float originData, const AnfQuantParam *quantParam) { + MS_ASSERT(quantParam != nullptr); + MS_ASSERT(quantParam->inited); + const auto scale = quantParam->scale; + const auto zeroPoint = quantParam->zeroPoint; + const auto numBit = quantParam->numBits; + const auto narrowRange = quantParam->narrowRange; + const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; + double minLimit; + if (narrowRange) { + minLimit = static_cast(1 - zeroPoint) * scale; + } else { + minLimit = static_cast(0 - zeroPoint) * scale; + } + return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { + double tmp = 0.0f; + if (originData > maxLimit) { + tmp = maxLimit; + } else if (originData < minLimit) { + tmp = minLimit; + } else { + tmp = originData; + } + auto quantData = static_cast(std::round(tmp / scale + zeroPoint)); + if (quantData == 0 && narrowRange) { + quantData++; + } + return quantData; + }(); +} + +void CalFakeNode(const AnfNodePtr &inTensor); + +STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType = QuantType_AwareTraining, + size_t bitNum = UINT8_QUANTIZATION); + +STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); + +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.cc b/mindspore/lite/tools/converter/quantizer/quantizer.cc new file mode 100644 index 00000000000..3480705c62d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantizer.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/lite/tools/converter/quantizer/quantizer.h" + +namespace mindspore { +namespace lite { +namespace quant { +Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) { + if (funcGraph == nullptr) { + return; + } +} + +STATUS Quantizer::GenerateQuantParam() { return RET_OK; } + +STATUS Quantizer::RemoveFakeQuant() { return RET_OK; } + +STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; } +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h new file mode 100644 index 00000000000..19284052f3a --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_QUANTIZER_H +#define MS_QUANTIZER_H + +#include +#include "include/errorcode.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "src/param_value_lite.h" + +namespace mindspore { +namespace lite { +namespace quant { +using STATUS = int; +enum QuantType { + QuantType_QUANT_NONE = 0, + QuantType_AwareTraining = 1, + QuantType_WeightQuant = 2, + QuantType_PostTraining = 3, + QuantType_MIN = QuantType_QUANT_NONE, + QuantType_MAX = QuantType_PostTraining +}; + +class Quantizer { + public: + explicit Quantizer(FuncGraphPtr graph); + + ~Quantizer() = default; + + virtual STATUS RemoveFakeQuant(); + + virtual STATUS GenerateQuantParam(); + + virtual STATUS DetermineNodeQuantType(); + + virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; + + protected: + FuncGraphPtr funcGraph = nullptr; +}; +} // namespace quant +} // namespace lite +} // namespace mindspore + +#endif + diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc new file mode 100644 index 00000000000..26846f80146 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/weight_quantizer.h" +#include +#include +#include "src/common/common.h" +#include "ir/dtype/type_id.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { + +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, + const std::string &convWeightChannelThreshold, const std::string &bitNum) + : Quantizer(graph) { + auto quantSize = static_cast(std::stoull(weightSize)); + this->bitNum = static_cast(std::stoull(bitNum)); + auto convQuantWeightChannelThreshold = static_cast(std::stoull(convWeightChannelThreshold)); + // TODO(...): update stractory + mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); +} + +// uint32_t GetConvChannel(TensorDefT *weight) { +// uint32_t channel = 0; +// const vector dims = weight->dims; + +// switch (weight->format) { +// case Format_NCHW: +// case Format_KCHW: +// case Format_NC4HW4: +// channel = static_cast(dims[NCHW_N]); +// break; +// case Format_NHWC: +// case Format_HWKC: +// channel = static_cast(dims[NHWC_N]); +// break; +// case Format_HWCK: +// channel = static_cast(dims[HWCK_K]); +// break; +// case Format_CKHW: +// channel = static_cast(dims[CKHW_K]); +// break; +// default: +// MS_LOGE("Unsupported format: %d", weight->format); +// return 0; +// } +// return channel; +// } + +STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { + for (auto &cnode : nodes) { + if (!mStrategy->CanConvOpQuantized(cnode)) { + continue; + } + + auto inputNode = cnode->input(2); + if (!inputNode->isa()) { + return RET_ERROR; + } + + auto paramNode = inputNode->cast(); + if (!paramNode->has_default()) { + return RET_ERROR; + } + + ParamValueLitePtr paramValue = std::static_pointer_cast(paramNode->default_param()); + auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { + for (auto &node : nodes) { + if (!mStrategy->CanMulOpQuantized(node)) { + continue; + } + + ParamValueLitePtr paramValue = nullptr; + for (size_t i = 1; i < node->size(); i++) { + auto inputNode = node->input(i); + if (inputNode->isa() == true) { + auto paramNode = inputNode->cast(); + if ((paramNode != nullptr) && (paramNode->has_default() == true)) { + paramValue = std::static_pointer_cast(paramNode->default_param()); + if ((paramValue == nullptr) || (paramValue->tensor_size() == 0) + || (paramValue->tensor_shape().size() != 4) + || (paramValue->tensor_addr() == nullptr) + || (paramValue->tensor_type() != mindspore::kNumberTypeFloat32)) { + paramValue = nullptr; + continue; + } else { + break; + } + } + } + } + if (paramValue == nullptr) { + MS_LOG(ERROR) << "No valid input param node !"; + continue; + } + auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "QunatFilter failed" << status; + return RET_ERROR; + } + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { + auto ret = RET_OK; + auto cnodes = funcGraph->GetOrderedCnodes(); + ret = DoConvQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; + return ret; + } + ret = DoMulQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoMulQuantize failed :" << ret; + return ret; + } + return ret; +} +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h new file mode 100644 index 00000000000..0726dd3df1e --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef WEIGHT_QUANTIZER_H +#define WEIGHT_QUANTIZER_H + +#include +#include +#include +#include "tools/converter/quantizer/quantizer.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace lite { +namespace quant { +class WeightQuantizer : public Quantizer { + public: + WeightQuantizer(FuncGraphPtr graph, const std::string& weightSize, + const std::string& covWeightChannelThreshold, const std::string& bitNum); + + ~WeightQuantizer() = default; + + STATUS DoQuantize(FuncGraphPtr funcGraph) override; + STATUS DoConvQuantize(const std::list &nodes); + STATUS DoMulQuantize(const std::list &nodes); + + private: + std::unique_ptr mStrategy; + size_t bitNum; +}; +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif + diff --git a/tests/ut/cpp/common/common_test.h b/tests/ut/cpp/common/common_test.h index a293584d7b9..8490046f13a 100644 --- a/tests/ut/cpp/common/common_test.h +++ b/tests/ut/cpp/common/common_test.h @@ -16,6 +16,9 @@ #ifndef TESTS_UT_COMMON_UT_COMMON_H_ #define TESTS_UT_COMMON_UT_COMMON_H_ +#include +#include +#include #include "gtest/gtest.h" namespace UT { class Common : public testing::Test { @@ -27,6 +30,47 @@ class Common : public testing::Test { // every TEST_F macro will enter one virtual void SetUp(); virtual void TearDown(); + + template + void PrintData(std::string name, T *output_data, int size) { + std::cout << "The " << name << " is as follows:" << std::endl; + if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << (int)output_data[i] << " "; + } + } else { + for (size_t i = 0; i < std::min(size, 100); i++) { + std::cout << output_data[i] << " "; + } + } + std::cout << std::endl; + } + + template + static void CompareOutputData(T *output_data, T *correct_data, int size, float err_bound) { + for (size_t i = 0; i < size; i++) { + T abs = fabs(output_data[i] - correct_data[i]); + ASSERT_LE(abs, err_bound); + } + } + + void ReadFile(const char *file, size_t *size, char **buf) { + ASSERT_NE(nullptr, file); + ASSERT_NE(nullptr, size); + ASSERT_NE(nullptr, buf); + std::string path = std::string(file); + std::ifstream ifs(path); + ASSERT_EQ(true, ifs.good()); + ASSERT_EQ(true, ifs.is_open()); + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + *buf = new char[*size]; + + ifs.seekg(0, std::ios::beg); + ifs.read(*buf, *size); + ifs.close(); + } }; } // namespace UT #endif // TESTS_UT_COMMON_UT_COMMON_H_