From 7bc95d8cd4faf4876ffad5c202a7495d94f84ad6 Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Fri, 21 Aug 2020 15:54:25 +0800 Subject: [PATCH] change primitive --- mindspore/lite/include/model.h | 2 +- mindspore/lite/src/CMakeLists.txt | 5 +- mindspore/lite/src/ir/primitive_t_value.cc | 40 ----- mindspore/lite/src/ir/primitive_t_value.h | 91 ----------- mindspore/lite/src/lite_kernel.h | 2 +- mindspore/lite/src/lite_session.cc | 3 +- mindspore/lite/src/lite_session.h | 2 +- mindspore/lite/src/model.cc | 3 +- mindspore/lite/src/ops/CMakeLists.txt | 3 - mindspore/lite/src/ops/abs.h | 5 +- mindspore/lite/src/ops/activation.cc | 12 +- mindspore/lite/src/ops/activation.h | 15 +- mindspore/lite/src/ops/activation_grad.cc | 6 +- mindspore/lite/src/ops/activation_grad.h | 16 +- mindspore/lite/src/ops/add.cc | 6 +- mindspore/lite/src/ops/add.h | 10 +- mindspore/lite/src/ops/addn.cc | 8 +- mindspore/lite/src/ops/addn.h | 16 +- mindspore/lite/src/ops/argmax.cc | 32 ++-- mindspore/lite/src/ops/argmax.h | 16 +- mindspore/lite/src/ops/argmin.cc | 32 ++-- mindspore/lite/src/ops/argmin.h | 16 +- mindspore/lite/src/ops/arithmetic.cc | 2 +- mindspore/lite/src/ops/arithmetic.h | 16 +- mindspore/lite/src/ops/arithmetic_self.cc | 2 +- mindspore/lite/src/ops/arithmetic_self.h | 18 +-- mindspore/lite/src/ops/batch_norm.cc | 6 +- mindspore/lite/src/ops/batch_norm.h | 16 +- mindspore/lite/src/ops/batch_to_space.cc | 14 +- mindspore/lite/src/ops/batch_to_space.h | 16 +- mindspore/lite/src/ops/bias_add.cc | 6 +- mindspore/lite/src/ops/bias_add.h | 16 +- mindspore/lite/src/ops/bias_grad.cc | 6 +- mindspore/lite/src/ops/bias_grad.h | 16 +- mindspore/lite/src/ops/bn_grad_input.cc | 12 +- mindspore/lite/src/ops/bn_grad_input.h | 16 +- mindspore/lite/src/ops/broadcast_to.cc | 6 +- mindspore/lite/src/ops/broadcast_to.h | 16 +- mindspore/lite/src/ops/caffe_p_relu.cc | 6 +- mindspore/lite/src/ops/caffe_p_relu.h | 17 +- mindspore/lite/src/ops/cast.cc | 14 +- mindspore/lite/src/ops/cast.h | 16 +- 
mindspore/lite/src/ops/ceil.h | 18 +-- mindspore/lite/src/ops/clip.cc | 12 +- mindspore/lite/src/ops/clip.h | 16 +- mindspore/lite/src/ops/concat.cc | 14 +- mindspore/lite/src/ops/concat.h | 16 +- mindspore/lite/src/ops/constant_of_shape.cc | 6 +- mindspore/lite/src/ops/constant_of_shape.h | 15 +- mindspore/lite/src/ops/conv2d.cc | 104 ++++++------ mindspore/lite/src/ops/conv2d.h | 16 +- mindspore/lite/src/ops/conv2d_grad_filter.cc | 104 ++++++------ mindspore/lite/src/ops/conv2d_grad_filter.h | 16 +- mindspore/lite/src/ops/conv2d_grad_input.cc | 104 ++++++------ mindspore/lite/src/ops/conv2d_grad_input.h | 16 +- mindspore/lite/src/ops/cos.h | 18 +-- mindspore/lite/src/ops/crop.cc | 12 +- mindspore/lite/src/ops/crop.h | 16 +- mindspore/lite/src/ops/deconv2d.cc | 104 ++++++------ mindspore/lite/src/ops/deconv2d.h | 16 +- mindspore/lite/src/ops/dedepthwise_conv2d.cc | 98 ++++++------ mindspore/lite/src/ops/dedepthwise_conv2d.h | 16 +- mindspore/lite/src/ops/depth_to_space.cc | 14 +- mindspore/lite/src/ops/depth_to_space.h | 16 +- mindspore/lite/src/ops/depthwise_conv2d.cc | 100 ++++++------ mindspore/lite/src/ops/depthwise_conv2d.h | 16 +- .../lite/src/ops/detection_post_process.cc | 82 +++++----- .../lite/src/ops/detection_post_process.h | 16 +- mindspore/lite/src/ops/div.cc | 6 +- mindspore/lite/src/ops/div.h | 17 +- mindspore/lite/src/ops/dropout.cc | 6 +- mindspore/lite/src/ops/dropout.h | 16 +- mindspore/lite/src/ops/eltwise.cc | 6 +- mindspore/lite/src/ops/eltwise.h | 16 +- mindspore/lite/src/ops/elu.cc | 6 +- mindspore/lite/src/ops/elu.h | 16 +- mindspore/lite/src/ops/embedding_lookup.cc | 8 +- mindspore/lite/src/ops/embedding_lookup.h | 16 +- .../lite/src/ops/embedding_lookup_sparse.cc | 18 +-- .../lite/src/ops/embedding_lookup_sparse.h | 16 +- mindspore/lite/src/ops/equal.h | 16 +- mindspore/lite/src/ops/exp.h | 16 +- mindspore/lite/src/ops/expand_dims.cc | 8 +- mindspore/lite/src/ops/expand_dims.h | 16 +- .../src/ops/fake_quant_with_min_max_vars.cc | 12 +- 
.../src/ops/fake_quant_with_min_max_vars.h | 16 +- mindspore/lite/src/ops/fill.cc | 8 +- mindspore/lite/src/ops/fill.h | 16 +- mindspore/lite/src/ops/flatten.cc | 2 +- mindspore/lite/src/ops/flatten.h | 16 +- mindspore/lite/src/ops/floor.h | 18 +-- mindspore/lite/src/ops/floor_div.h | 16 +- mindspore/lite/src/ops/floor_mod.h | 16 +- mindspore/lite/src/ops/full_connection.cc | 26 +-- mindspore/lite/src/ops/full_connection.h | 16 +- mindspore/lite/src/ops/fused_batchnorm.cc | 18 +-- mindspore/lite/src/ops/fused_batchnorm.h | 16 +- mindspore/lite/src/ops/gather.cc | 14 +- mindspore/lite/src/ops/gather.h | 16 +- mindspore/lite/src/ops/gather_nd.cc | 8 +- mindspore/lite/src/ops/gather_nd.h | 16 +- mindspore/lite/src/ops/greater.h | 15 +- mindspore/lite/src/ops/greater_equal.h | 16 +- mindspore/lite/src/ops/l2_norm.cc | 12 +- mindspore/lite/src/ops/l2_norm.h | 16 +- mindspore/lite/src/ops/leaky_relu.cc | 6 +- mindspore/lite/src/ops/leaky_relu.h | 16 +- mindspore/lite/src/ops/less.h | 16 +- mindspore/lite/src/ops/less_equal.h | 16 +- .../src/ops/local_response_normalization.cc | 24 +-- .../src/ops/local_response_normalization.h | 16 +- mindspore/lite/src/ops/log.h | 16 +- mindspore/lite/src/ops/logical_and.h | 18 +-- mindspore/lite/src/ops/logical_not.h | 18 +-- mindspore/lite/src/ops/logical_or.h | 18 +-- mindspore/lite/src/ops/lrn.cc | 24 +-- mindspore/lite/src/ops/lrn.h | 16 +- mindspore/lite/src/ops/lstm.cc | 8 +- mindspore/lite/src/ops/lstm.h | 16 +- mindspore/lite/src/ops/matmul.cc | 14 +- mindspore/lite/src/ops/matmul.h | 16 +- mindspore/lite/src/ops/matrix_diag.cc | 24 +-- mindspore/lite/src/ops/matrix_diag.h | 16 +- mindspore/lite/src/ops/maximum.h | 18 +-- mindspore/lite/src/ops/mean.cc | 14 +- mindspore/lite/src/ops/mean.h | 16 +- mindspore/lite/src/ops/minimum.h | 18 +-- mindspore/lite/src/ops/mul.cc | 6 +- mindspore/lite/src/ops/mul.h | 17 +- mindspore/lite/src/ops/nchw2nhwc.cc | 2 +- mindspore/lite/src/ops/nchw2nhwc.h | 16 +- 
mindspore/lite/src/ops/nhwc2nchw.cc | 2 +- mindspore/lite/src/ops/nhwc2nchw.h | 16 +- mindspore/lite/src/ops/not_equal.h | 18 +-- mindspore/lite/src/ops/one_hot.cc | 8 +- mindspore/lite/src/ops/one_hot.h | 16 +- mindspore/lite/src/ops/pad.cc | 22 +-- mindspore/lite/src/ops/pad.h | 16 +- mindspore/lite/src/ops/permute.cc | 6 +- mindspore/lite/src/ops/permute.h | 16 +- mindspore/lite/src/ops/pooling.cc | 82 +++++----- mindspore/lite/src/ops/pooling.h | 16 +- mindspore/lite/src/ops/pooling_grad.cc | 78 ++++----- mindspore/lite/src/ops/pooling_grad.h | 16 +- mindspore/lite/src/ops/power.cc | 20 +-- mindspore/lite/src/ops/power.h | 16 +- mindspore/lite/src/ops/power_grad.cc | 18 +-- mindspore/lite/src/ops/power_grad.h | 16 +- mindspore/lite/src/ops/prelu.cc | 6 +- mindspore/lite/src/ops/prelu.h | 18 +-- mindspore/lite/src/ops/primitive_c.cc | 93 ++++++++--- mindspore/lite/src/ops/primitive_c.h | 106 +++++++++++-- mindspore/lite/src/ops/prior_box.cc | 66 ++++---- mindspore/lite/src/ops/prior_box.h | 16 +- mindspore/lite/src/ops/quant_dtype_cast.cc | 14 +- mindspore/lite/src/ops/quant_dtype_cast.h | 16 +- mindspore/lite/src/ops/range.cc | 26 +-- mindspore/lite/src/ops/range.h | 16 +- mindspore/lite/src/ops/rank.cc | 2 +- mindspore/lite/src/ops/rank.h | 16 +- mindspore/lite/src/ops/reduce.cc | 20 +-- mindspore/lite/src/ops/reduce.h | 16 +- mindspore/lite/src/ops/reshape.cc | 31 ++-- mindspore/lite/src/ops/reshape.h | 16 +- mindspore/lite/src/ops/resize.cc | 38 ++--- mindspore/lite/src/ops/resize.h | 16 +- mindspore/lite/src/ops/reverse.cc | 6 +- mindspore/lite/src/ops/reverse.h | 16 +- mindspore/lite/src/ops/reverse_sequence.cc | 18 +-- mindspore/lite/src/ops/reverse_sequence.h | 16 +- mindspore/lite/src/ops/roi_pooling.cc | 20 +-- mindspore/lite/src/ops/roi_pooling.h | 16 +- mindspore/lite/src/ops/round.h | 18 +-- mindspore/lite/src/ops/rsqrt.h | 18 +-- mindspore/lite/src/ops/scale.cc | 6 +- mindspore/lite/src/ops/scale.h | 16 +- mindspore/lite/src/ops/scatter_nd.h | 16 
+- mindspore/lite/src/ops/shape.h | 16 +- mindspore/lite/src/ops/sin.h | 18 +-- mindspore/lite/src/ops/slice.cc | 20 +-- mindspore/lite/src/ops/slice.h | 16 +- mindspore/lite/src/ops/softmax.cc | 8 +- mindspore/lite/src/ops/softmax.h | 16 +- .../lite/src/ops/softmax_cross_entropy.cc | 6 +- .../lite/src/ops/softmax_cross_entropy.h | 16 +- mindspore/lite/src/ops/space_to_batch.cc | 14 +- mindspore/lite/src/ops/space_to_batch.h | 16 +- mindspore/lite/src/ops/space_to_batch_nd.cc | 14 +- mindspore/lite/src/ops/space_to_batch_nd.h | 16 +- mindspore/lite/src/ops/space_to_depth.cc | 14 +- mindspore/lite/src/ops/space_to_depth.h | 16 +- mindspore/lite/src/ops/sparse_to_dense.cc | 30 ++-- mindspore/lite/src/ops/sparse_to_dense.h | 16 +- mindspore/lite/src/ops/split.cc | 20 +-- mindspore/lite/src/ops/split.h | 16 +- mindspore/lite/src/ops/sqrt.h | 18 +-- mindspore/lite/src/ops/square.h | 17 +- mindspore/lite/src/ops/squared_difference.h | 18 +-- mindspore/lite/src/ops/squeeze.cc | 8 +- mindspore/lite/src/ops/squeeze.h | 16 +- mindspore/lite/src/ops/stack.cc | 20 +-- mindspore/lite/src/ops/stack.h | 16 +- mindspore/lite/src/ops/strided_slice.cc | 56 +++---- mindspore/lite/src/ops/strided_slice.h | 16 +- mindspore/lite/src/ops/sub.cc | 6 +- mindspore/lite/src/ops/sub.h | 18 +-- mindspore/lite/src/ops/tile.cc | 14 +- mindspore/lite/src/ops/tile.h | 16 +- mindspore/lite/src/ops/topk.cc | 14 +- mindspore/lite/src/ops/topk.h | 16 +- mindspore/lite/src/ops/transpose.cc | 14 +- mindspore/lite/src/ops/transpose.h | 16 +- mindspore/lite/src/ops/unique.cc | 8 +- mindspore/lite/src/ops/unique.h | 16 +- mindspore/lite/src/ops/unsqueeze.cc | 8 +- mindspore/lite/src/ops/unsqueeze.h | 16 +- mindspore/lite/src/ops/unstack.cc | 12 +- mindspore/lite/src/ops/unstack.h | 16 +- mindspore/lite/src/ops/upsample.cc | 13 +- mindspore/lite/src/ops/upsample.h | 18 +-- mindspore/lite/src/ops/where.cc | 8 +- mindspore/lite/src/ops/where.h | 16 +- mindspore/lite/src/ops/zeros_like.cc | 2 +- 
mindspore/lite/src/ops/zeros_like.h | 16 +- mindspore/lite/src/populate_parameter.cc | 148 ++++++++++-------- .../runtime/kernel/arm/base/reduce_base.cc | 4 +- .../runtime/kernel/arm/base/resize_base.cc | 2 +- .../kernel/arm/fp32/arithmetic_self.cc | 2 +- .../lite/src/runtime/kernel/arm/fp32/pad.cc | 3 +- .../runtime/kernel/arm/int8/reduce_int8.cc | 25 +-- .../runtime/kernel/arm/int8/resize_int8.cc | 29 +++- mindspore/lite/src/scheduler.cc | 3 +- mindspore/lite/test/CMakeLists.txt | 3 +- .../runtime/kernel/arm/int8/add_int8_tests.cc | 2 +- .../arm/int8/arithmetic_self_int8_tests.cc | 1 + .../kernel/arm/int8/batchnorm_int8_test.cc | 1 + .../kernel/arm/int8/bias_add_int8_tests.cc | 1 + .../kernel/arm/int8/concat_int8_tests.cc | 1 + .../kernel/arm/int8/crop_int8_tests.cc | 1 + .../kernel/arm/int8/deconv_int8_tests.cc | 1 + .../runtime/kernel/arm/int8/div_int8_test.cc | 1 + .../arm/int8/fullconnection_int8_tests.cc | 1 + .../kernel/arm/int8/hswish_int8_tests.cc | 1 + .../kernel/arm/int8/matmul_int8_tests.cc | 2 + .../runtime/kernel/arm/int8/mul_int8_tests.cc | 1 + .../runtime/kernel/arm/int8/pad_int8_tests.cc | 1 + .../kernel/arm/int8/power_int8_tests.cc | 1 + .../kernel/arm/int8/prelu_int8_tests.cc | 1 + .../kernel/arm/int8/quant_dtype_cast_tests.cc | 1 + .../kernel/arm/int8/reduce_int8_tests.cc | 1 + .../kernel/arm/int8/relux_int8_tests.cc | 1 + .../kernel/arm/int8/reshape_int8_tests.cc | 1 + .../arm/int8/resize_bilinear_int8_tests.cc | 1 + .../resize_nearest_neighbor_int8_tests.cc | 1 + .../kernel/arm/int8/sigmoid_int8_tests.cc | 1 + .../kernel/arm/int8/slice_int8_tests.cc | 1 + .../kernel/arm/int8/softmax_int8_tests.cc | 1 + .../kernel/arm/int8/split_int8_tests.cc | 1 + .../kernel/arm/int8/squeeze_int8_tests.cc | 1 + .../runtime/kernel/arm/int8/sub_int_tests.cc | 1 + .../kernel/arm/int8/topk_int8_tests.cc | 1 + .../kernel/arm/int8/unsqueeze_int8_tests.cc | 1 + .../lite/tools/anf_exporter/anf_exporter.cc | 16 +- .../lite/tools/anf_exporter/anf_exporter.h | 5 +- 
.../anf_populater/anf_activation_populater.cc | 6 +- .../anf_populater/anf_activation_populater.h | 4 +- .../anf_populater/anf_batchnorm_populater.cc | 6 +- .../anf_populater/anf_batchnorm_populater.h | 3 +- .../anf_populater/anf_biasadd_populater.cc | 6 +- .../anf_populater/anf_biasadd_populater.h | 3 +- .../anf_populater/anf_concat_populater.cc | 6 +- .../anf_populater/anf_concat_populater.h | 3 +- .../anf_populater/anf_conv_populater.cc | 12 +- .../anf_populater/anf_conv_populater.h | 2 +- .../anf_depthwiseconv2d_populater.cc | 12 +- .../anf_depthwiseconv2d_populater.h | 2 +- .../anf_populater/anf_dequant_populater.cc | 6 +- .../anf_populater/anf_dequant_populater.h | 3 +- .../anf_populater/anf_flatten_populater.cc | 6 +- .../anf_populater/anf_flatten_populater.h | 3 +- .../anf_populater/anf_make_tuple_populater.cc | 6 +- .../anf_populater/anf_make_tuple_populater.h | 3 +- .../anf_populater/anf_matmul_populater.cc | 12 +- .../anf_populater/anf_matmul_populater.h | 3 +- .../anf_populater/anf_mul_populater.cc | 6 +- .../anf_populater/anf_mul_populater.h | 3 +- .../anf_populater/anf_node_populater.h | 5 +- .../anf_populater/anf_pool_populater.cc | 6 +- .../anf_populater/anf_pool_populater.h | 2 +- .../anf_populater/anf_quant_populater.cc | 6 +- .../anf_populater/anf_quant_populater.h | 2 +- .../anf_populater/anf_reducemean_populater.cc | 6 +- .../anf_populater/anf_reducemean_populater.h | 2 +- .../anf_populater/anf_reshape_populater.cc | 6 +- .../anf_populater/anf_reshape_populater.h | 2 +- .../anf_populater/anf_tensoradd_populater.cc | 6 +- .../anf_populater/anf_tensoradd_populater.h | 2 +- .../anf_populater/anf_transpose_populater.cc | 6 +- .../anf_populater/anf_transpose_populater.h | 2 +- .../anf_tuple_getitem_populater.cc | 6 +- .../anf_tuple_getitem_populater.h | 2 +- .../anf_importer/import_from_meta_graphT.cc | 33 ++-- .../anf_importer/import_from_meta_graphT.h | 2 +- .../anf_importer/import_from_protobuf.cc | 52 +++--- 
mindspore/lite/tools/benchmark/benchmark.h | 3 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 5 +- .../quantizer/post_training_quantizer.cc | 34 ++-- .../quantizer/post_training_quantizer.h | 9 +- .../tools/converter/quantizer/quant_cast.cc | 14 +- .../converter/quantizer/quantize_util.cc | 14 +- .../tools/converter/quantizer/quantize_util.h | 10 +- .../lite/tools/optimizer/common/gllo_utils.cc | 20 +-- .../lite/tools/optimizer/common/gllo_utils.h | 4 +- .../fusion/constant_folding_fusion.cc | 7 +- .../fusion/conv_activation_fusion.cc | 10 +- .../optimizer/fusion/conv_biasadd_fusion.cc | 8 +- .../tools/optimizer/fusion/conv_bn_fusion.cc | 4 +- .../optimizer/fusion/conv_scale_fusion.cc | 4 +- .../optimizer/fusion/conv_transform_fusion.cc | 12 +- .../lite/tools/time_profile/time_profile.h | 7 +- 321 files changed, 2413 insertions(+), 2580 deletions(-) delete mode 100644 mindspore/lite/src/ir/primitive_t_value.cc delete mode 100644 mindspore/lite/src/ir/primitive_t_value.h diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h index 2a880ca0ae7..7e1a005414d 100644 --- a/mindspore/lite/include/model.h +++ b/mindspore/lite/include/model.h @@ -20,7 +20,7 @@ #include #include #include -#include "schema/model_generated.h" +#include "src/ops/primitive_c.h" namespace mindspore { #define MS_API __attribute__((visibility("default"))) diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 34751013b2e..6aa01502064 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -32,11 +32,10 @@ set(ANF_SRC ${ANF_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc ) - -add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC}) +file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc) +add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC} ${C_OPS_SRC}) target_link_libraries(mindspore-lite cpu_kernel_mid_ - c_ops_mid ) 
add_subdirectory(runtime/kernel/arm) diff --git a/mindspore/lite/src/ir/primitive_t_value.cc b/mindspore/lite/src/ir/primitive_t_value.cc deleted file mode 100644 index ecb757c6eeb..00000000000 --- a/mindspore/lite/src/ir/primitive_t_value.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/ir/primitive_t_value.h" - -namespace mindspore::lite { -std::shared_ptr GetReturnPrim() { - auto return_primitiveT = new schema::PrimitiveT; - return_primitiveT->value.type = schema::PrimitiveType_Return; - return_primitiveT->value.value = new schema::ReturnT; - return std::make_shared(return_primitiveT); -} - -std::shared_ptr GetMakeTuplePrim() { - auto make_tuple_primitiveT = new schema::PrimitiveT; - make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; - make_tuple_primitiveT->value.value = new schema::MakeTupleT; - return std::make_shared(make_tuple_primitiveT); -} - -std::shared_ptr GetTupleGetItemPrim() { - auto tuple_get_item_primitiveT = new schema::PrimitiveT(); - tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; - tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT; - return std::make_shared(tuple_get_item_primitiveT); -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h deleted file mode 100644 index 
c3363014ab3..00000000000 --- a/mindspore/lite/src/ir/primitive_t_value.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ -#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ - -#include -#include -#include "schema/inner/model_generated.h" -#include "ir/value.h" - -namespace mindspore::lite { - -class PrimitiveTValue : public Value { - public: - explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {} - // not responsible to free primitive, the one created the dynamic memory is responsible to free it. 
- ~PrimitiveTValue() override = default; - - MS_DECLARE_PARENT(PrimitiveTValue, Value) - - schema::PrimitiveT *GetPrimitiveT() const { return this->primitive; } - - void SetPrimitiveT(schema::PrimitiveT *primIn) { this->primitive = primIn; } - - bool operator==(const Value &rhs) const override { - if (rhs.isa()) { - auto other_prim = static_cast(rhs); - auto a = this->primitive->value.type; - auto b = other_prim.primitive->value.type; - return a == b; - } else { - return false; - } - } - - void SetInputQuantParam(const std::vector> &input_quant_param) { - this->input_quant_param_ = input_quant_param; - } - - void SetOutputQuantParam(const std::vector> &output_quant_param) { - this->output_quant_param_ = output_quant_param; - } - - void ClearInputOutputQuantParam() { - input_quant_param_.clear(); - output_quant_param_.clear(); - } - - void AddInputQuantParam(std::vector quant_param) { - this->input_quant_param_.emplace_back(quant_param); - } - std::vector> GetInputQuantParams() const { return input_quant_param_; } - - void AddOutputQuantParam(std::vector quant_param) { - this->output_quant_param_.emplace_back(quant_param); - } - std::vector> GetOutputQuantParams() const { return output_quant_param_; } - - void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } - - schema::QuantType GetQuantType() const { return quant_type_; } - - protected: - schema::PrimitiveT *primitive = nullptr; - std::vector> input_quant_param_; - std::vector> output_quant_param_; - schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; -}; - -std::shared_ptr GetReturnPrim(); - -std::shared_ptr GetMakeTuplePrim(); - -std::shared_ptr GetTupleGetItemPrim(); -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 7eee4dfb32b..07adf896928 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -21,11 +21,11 @@ 
#ifdef ENABLE_ARM #include #endif +#include "src/ops/primitive_c.h" #include "src/runtime/kernel/arm/nnacl/op_base.h" #include "include/context.h" #include "src/ir/tensor.h" #include "include/errorcode.h" -#include "src/ops/primitive_c.h" #ifdef ENABLE_FP16 using FLOAT_t = float16_t; diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index d2ba229d17a..4b8946ace42 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -14,9 +14,9 @@ * limitations under the License. */ +#include "src/lite_session.h" #include #include "include/errorcode.h" -#include "src/lite_session.h" #include "utils/log_adapter.h" #include "src/scheduler.h" #include "src/runtime/runtime_api.h" @@ -76,6 +76,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) { this->tensors_.emplace_back(dstTensor); } + return RET_OK; } diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index bada7c3f93d..c034ade083d 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -21,11 +21,11 @@ #include #include #include +#include "src/lite_kernel.h" #include "include/ms_tensor.h" #include "include/lite_session.h" #include "include/model.h" #include "include/context.h" -#include "src/lite_kernel.h" #include "schema/model_generated.h" #include "src/executor.h" diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 6059fb3eff8..48a84f177cd 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "include/model.h" #include "src/ops/unique.h" #include "src/ops/space_to_batch.h" #include "src/ops/conv2d.h" @@ -106,8 +107,6 @@ #include "src/ops/squared_difference.h" #include "src/ops/ceil.h" #include "src/ops/round.h" -#include "src/ops/primitive_c.h" -#include "include/model.h" #include "utils/log_adapter.h" namespace mindspore::lite { diff --git a/mindspore/lite/src/ops/CMakeLists.txt b/mindspore/lite/src/ops/CMakeLists.txt index 06ad3db3f33..e69de29bb2d 100644 --- a/mindspore/lite/src/ops/CMakeLists.txt +++ b/mindspore/lite/src/ops/CMakeLists.txt @@ -1,3 +0,0 @@ -file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) - -add_library(c_ops_mid OBJECT ${C_OPS_SRC}) \ No newline at end of file diff --git a/mindspore/lite/src/ops/abs.h b/mindspore/lite/src/ops/abs.h index 82deee4452d..faac704bd23 100644 --- a/mindspore/lite/src/ops/abs.h +++ b/mindspore/lite/src/ops/abs.h @@ -32,7 +32,10 @@ namespace mindspore { namespace lite { class Abs : public ArithmeticSelf { public: - explicit Abs(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index f983b3d7723..608241cfa52 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Activation::GetType() const { return this->primitive->value.AsActivation()->type; } -float Activation::GetAlpha() const { return this->primitive->value.AsActivation()->alpha; } +int Activation::GetType() const { return this->primitive_->value.AsActivation()->type; } +float Activation::GetAlpha() const { return this->primitive_->value.AsActivation()->alpha; } -void 
Activation::SetType(int type) { this->primitive->value.AsActivation()->type = (schema::ActivationType)type; } -void Activation::SetAlpha(float alpha) { this->primitive->value.AsActivation()->alpha = alpha; } +void Activation::SetType(int type) { this->primitive_->value.AsActivation()->type = (schema::ActivationType)type; } +void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; } #else -int Activation::GetType() const { return this->primitive->value_as_Activation()->type(); } -float Activation::GetAlpha() const { return this->primitive->value_as_Activation()->alpha(); } +int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); } +float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); } void Activation::SetType(int type) {} void Activation::SetAlpha(float alpha) {} diff --git a/mindspore/lite/src/ops/activation.h b/mindspore/lite/src/ops/activation.h index a05fa5700ad..a4ab25c1d03 100644 --- a/mindspore/lite/src/ops/activation.h +++ b/mindspore/lite/src/ops/activation.h @@ -13,26 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_ #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_ namespace mindspore { namespace lite { class Activation : public PrimitiveC { public: - explicit Activation(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetType() const; float GetAlpha() const; void SetType(int type); diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc index a27d41ccbc4..6ac7d9181e1 100644 --- a/mindspore/lite/src/ops/activation_grad.cc +++ b/mindspore/lite/src/ops/activation_grad.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ActivationGrad::GetType() const { return this->primitive->value.AsActivationGrad()->type; } +int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; } void ActivationGrad::SetType(int type) { - this->primitive->value.AsActivationGrad()->type = (schema::ActivationGradType)type; + this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type; } #else -int ActivationGrad::GetType() const { return this->primitive->value_as_ActivationGrad()->type(); } +int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } void ActivationGrad::SetType(int type) {} #endif diff --git a/mindspore/lite/src/ops/activation_grad.h b/mindspore/lite/src/ops/activation_grad.h index f0fd0a329a5..64e01a90a98 100644 --- 
a/mindspore/lite/src/ops/activation_grad.h +++ b/mindspore/lite/src/ops/activation_grad.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_ namespace mindspore { namespace lite { class ActivationGrad : public PrimitiveC { public: - explicit ActivationGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetType() const; void SetType(int type); diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index 5a51f15760f..9328cbd4e7f 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Add::GetActivationType() const { return this->primitive->value.AsAdd()->activationType; } +int Add::GetActivationType() const { return this->primitive_->value.AsAdd()->activationType; } void Add::SetActivationType(int activation_type) { - this->primitive->value.AsAdd()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsAdd()->activationType = (schema::ActivationType)activation_type; } #else -int Add::GetActivationType() const { return this->primitive->value_as_Add()->activationType(); } +int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } void Add::SetActivationType(int activation_type) {} #endif diff --git 
a/mindspore/lite/src/ops/add.h b/mindspore/lite/src/ops/add.h index 444ef326522..93c79cccd39 100644 --- a/mindspore/lite/src/ops/add.h +++ b/mindspore/lite/src/ops/add.h @@ -14,6 +14,9 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_ + #include #include #include @@ -24,14 +27,15 @@ #else #include "schema/model_generated.h" #endif -#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_ namespace mindspore { namespace lite { class Add : public Arithmetic { public: - explicit Add(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {} int GetActivationType() const; void SetActivationType(int activation_type); diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 5795384365e..6562f03df0d 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int AddN::GetN() const { return this->primitive->value.AsAddN()->N; } +int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; } -void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; } +void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; } #else -int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); } +int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); } void AddN::SetN(int n) {} #endif @@ -34,7 +34,7 @@ namespace { constexpr int kLeastInputNum = 2; } int AddN::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs.front(); MS_ASSERT(input != nullptr); auto output = outputs.front(); diff --git a/mindspore/lite/src/ops/addn.h 
b/mindspore/lite/src/ops/addn.h index 8de647e12d2..bd50b7d8502 100644 --- a/mindspore/lite/src/ops/addn.h +++ b/mindspore/lite/src/ops/addn.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_ namespace mindspore { namespace lite { class AddN : public PrimitiveC { public: - explicit AddN(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetN() const; void SetN(int n); diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index 9d7a7d225f2..3bdb91ef67d 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -19,25 +19,25 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; } -bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; } -int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; } -bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; } -int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; } +int ArgMax::GetAxis() const { return this->primitive_->value.AsArgMax()->axis; } +bool ArgMax::GetOutMaxValue() const { return this->primitive_->value.AsArgMax()->outMaxValue; } +int ArgMax::GetTopK() const { return 
this->primitive_->value.AsArgMax()->topK; } +bool ArgMax::GetKeepDims() const { return this->primitive_->value.AsArgMax()->keepDims; } +int ArgMax::GetAxisType() const { return this->primitive_->value.AsArgMax()->axisType; } -void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; } -void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; } -void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; } -void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; } -void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; } +void ArgMax::SetAxis(int axis) { this->primitive_->value.AsArgMax()->axis = axis; } +void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMax()->outMaxValue = out_max_value; } +void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; } +void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; } +void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; } #else -int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); } -bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); } -int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); } -bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); } -int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); } +int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); } +bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); } +int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); } +bool ArgMax::GetKeepDims() const { return 
this->primitive_->value_as_ArgMax()->keepDims(); } +int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); } void ArgMax::SetAxis(int axis) {} void ArgMax::SetOutMaxValue(bool out_max_value) {} @@ -47,7 +47,7 @@ void ArgMax::SetAxisType(int axis_type) {} #endif int ArgMax::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/argmax.h b/mindspore/lite/src/ops/argmax.h index a80566d97a7..ed89d1a3a6e 100644 --- a/mindspore/lite/src/ops/argmax.h +++ b/mindspore/lite/src/ops/argmax.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_ namespace mindspore { namespace lite { class ArgMax : public PrimitiveC { public: - explicit ArgMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index e9d4599558d..b95042a1286 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -19,25 +19,25 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; } -bool 
ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; } -int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; } -bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; } -int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; } +int ArgMin::GetAxis() const { return this->primitive_->value.AsArgMin()->axis; } +bool ArgMin::GetOutMaxValue() const { return this->primitive_->value.AsArgMin()->outMaxValue; } +int ArgMin::GetTopK() const { return this->primitive_->value.AsArgMin()->topK; } +bool ArgMin::GetKeepDims() const { return this->primitive_->value.AsArgMin()->keepDims; } +int ArgMin::GetAxisType() const { return this->primitive_->value.AsArgMin()->axisType; } -void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; } -void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; } -void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; } -void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; } -void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; } +void ArgMin::SetAxis(int axis) { this->primitive_->value.AsArgMin()->axis = axis; } +void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMin()->outMaxValue = out_max_value; } +void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top_k; } +void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; } +void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } #else -int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); } -bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); } -int ArgMin::GetTopK() const 
{ return this->primitive->value_as_ArgMin()->topK(); } -bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); } -int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); } +int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); } +bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); } +int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); } +bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); } +int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); } void ArgMin::SetAxis(int axis) {} void ArgMin::SetOutMaxValue(bool out_max_value) {} @@ -47,7 +47,7 @@ void ArgMin::SetAxisType(int axis_type) {} #endif int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/argmin.h b/mindspore/lite/src/ops/argmin.h index adb5c92bf3c..35b6a3b696f 100644 --- a/mindspore/lite/src/ops/argmin.h +++ b/mindspore/lite/src/ops/argmin.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_ namespace mindspore { namespace lite { class ArgMin : public PrimitiveC { public: - explicit ArgMin(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc index 9e34d9d6a54..fa352ce08e6 100644 --- a/mindspore/lite/src/ops/arithmetic.cc +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace lite { int Arithmetic::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { MS_LOG(ERROR) << "The number of input must be " << kDoubleNum; return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/arithmetic.h b/mindspore/lite/src/ops/arithmetic.h index 880d5422263..eee942eb541 100644 --- a/mindspore/lite/src/ops/arithmetic.h +++ b/mindspore/lite/src/ops/arithmetic.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_ namespace mindspore { namespace lite { class Arithmetic : public PrimitiveC { public: - explicit Arithmetic(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; bool Broadcasting() { return this->broadcasting_; } diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc index f4facb76924..9a4fa1546d1 100644 --- a/mindspore/lite/src/ops/arithmetic_self.cc +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace lite { int ArithmeticSelf::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/arithmetic_self.h b/mindspore/lite/src/ops/arithmetic_self.h index 3cc7a748ccb..8ecf5c8899e 100644 --- a/mindspore/lite/src/ops/arithmetic_self.h +++ b/mindspore/lite/src/ops/arithmetic_self.h @@ -14,24 +14,20 @@ * limitations under the License. 
*/ -#include -#include -#include -#include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif #ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_ #define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_ +#include +#include "src/ops/primitive_c.h" + namespace mindspore { namespace lite { class ArithmeticSelf : public PrimitiveC { public: - explicit ArithmeticSelf(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/batch_norm.cc b/mindspore/lite/src/ops/batch_norm.cc index fc7026ce110..dc7e60015bd 100644 --- a/mindspore/lite/src/ops/batch_norm.cc +++ b/mindspore/lite/src/ops/batch_norm.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float BatchNorm::GetEpsilon() const { return this->primitive->value.AsBatchNorm()->epsilon; } +float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm()->epsilon; } -void BatchNorm::SetEpsilon(float epsilon) { this->primitive->value.AsBatchNorm()->epsilon = epsilon; } +void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; } #else -float BatchNorm::GetEpsilon() const { return this->primitive->value_as_BatchNorm()->epsilon(); } +float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } void BatchNorm::SetEpsilon(float epsilon) {} #endif diff --git a/mindspore/lite/src/ops/batch_norm.h b/mindspore/lite/src/ops/batch_norm.h index 8745f96503b..4f3a10a9252 100644 --- a/mindspore/lite/src/ops/batch_norm.h +++ b/mindspore/lite/src/ops/batch_norm.h @@ -14,25 +14,23 
@@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_ namespace mindspore { namespace lite { class BatchNorm : public PrimitiveC { public: - explicit BatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetEpsilon() const; void SetEpsilon(float epsilon); diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index cd1c897b3b4..c11a5ffd20f 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -23,22 +23,22 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; } -std::vector BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; } +std::vector BatchToSpace::GetBlockShape() const { return this->primitive_->value.AsBatchToSpace()->blockShape; } +std::vector BatchToSpace::GetCrops() const { return this->primitive_->value.AsBatchToSpace()->crops; } void BatchToSpace::SetBlockShape(const std::vector &block_shape) { - this->primitive->value.AsBatchToSpace()->blockShape = block_shape; + this->primitive_->value.AsBatchToSpace()->blockShape = block_shape; } -void BatchToSpace::SetCrops(const std::vector &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; } +void BatchToSpace::SetCrops(const std::vector &crops) { 
this->primitive_->value.AsBatchToSpace()->crops = crops; } #else std::vector BatchToSpace::GetBlockShape() const { - auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape(); + auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector BatchToSpace::GetCrops() const { - auto fb_vector = this->primitive->value_as_BatchToSpace()->crops(); + auto fb_vector = this->primitive_->value_as_BatchToSpace()->crops(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -53,7 +53,7 @@ constexpr int kCropsSize = 4; } // namespace int BatchToSpace::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/batch_to_space.h b/mindspore/lite/src/ops/batch_to_space.h index 18c3ff239ce..e107803b033 100644 --- a/mindspore/lite/src/ops/batch_to_space.h +++ b/mindspore/lite/src/ops/batch_to_space.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_ namespace mindspore { namespace lite { class BatchToSpace : public PrimitiveC { public: - explicit BatchToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 24c0f707ad3..4cada2551cf 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector BiasAdd::GetAxis() const { return this->primitive->value.AsBiasAdd()->axis; } +std::vector BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; } -void BiasAdd::SetAxis(const std::vector &axis) { this->primitive->value.AsBiasAdd()->axis = axis; } +void BiasAdd::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; } #else std::vector BiasAdd::GetAxis() const { - auto fb_vector = this->primitive->value_as_BiasAdd()->axis(); + auto fb_vector = this->primitive_->value_as_BiasAdd()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/bias_add.h b/mindspore/lite/src/ops/bias_add.h index e9d7ad0a60a..4526d0b3c74 100644 --- a/mindspore/lite/src/ops/bias_add.h +++ 
b/mindspore/lite/src/ops/bias_add.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_ namespace mindspore { namespace lite { class BiasAdd : public PrimitiveC { public: - explicit BiasAdd(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 9912a4747fb..6fc1caa616c 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector BiasGrad::GetAxis() const { return this->primitive->value.AsBiasGrad()->axis; } +std::vector BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; } -void BiasGrad::SetAxis(const std::vector &axis) { this->primitive->value.AsBiasGrad()->axis = axis; } +void BiasGrad::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } #else std::vector BiasGrad::GetAxis() const { - auto fb_vector = this->primitive->value_as_BiasGrad()->axis(); + auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/bias_grad.h b/mindspore/lite/src/ops/bias_grad.h index bf135f3d0ac..9d6cfea64e8 100644 --- 
a/mindspore/lite/src/ops/bias_grad.h +++ b/mindspore/lite/src/ops/bias_grad.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_ namespace mindspore { namespace lite { class BiasGrad : public PrimitiveC { public: - explicit BiasGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/bn_grad_input.cc b/mindspore/lite/src/ops/bn_grad_input.cc index cbc851f887d..1736e1fe9c6 100644 --- a/mindspore/lite/src/ops/bn_grad_input.cc +++ b/mindspore/lite/src/ops/bn_grad_input.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float BNGradInput::GetEps() const { return this->primitive->value.AsBNGradInput()->eps; } -int BNGradInput::GetChannels() const { return this->primitive->value.AsBNGradInput()->channels; } +float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; } +int BNGradInput::GetChannels() const { return this->primitive_->value.AsBNGradInput()->channels; } -void BNGradInput::SetEps(float eps) { this->primitive->value.AsBNGradInput()->eps = eps; } -void BNGradInput::SetChannels(int channels) { this->primitive->value.AsBNGradInput()->channels = channels; } +void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; } +void BNGradInput::SetChannels(int channels) { 
this->primitive_->value.AsBNGradInput()->channels = channels; } #else -float BNGradInput::GetEps() const { return this->primitive->value_as_BNGradInput()->eps(); } -int BNGradInput::GetChannels() const { return this->primitive->value_as_BNGradInput()->channels(); } +float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); } +int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); } void BNGradInput::SetEps(float eps) {} void BNGradInput::SetChannels(int channels) {} diff --git a/mindspore/lite/src/ops/bn_grad_input.h b/mindspore/lite/src/ops/bn_grad_input.h index 04764696991..5a138439300 100644 --- a/mindspore/lite/src/ops/bn_grad_input.h +++ b/mindspore/lite/src/ops/bn_grad_input.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_ namespace mindspore { namespace lite { class BNGradInput : public PrimitiveC { public: - explicit BNGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetEps() const; int GetChannels() const; diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 38b3cc64518..ca2d71607d2 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector 
BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; } +std::vector BroadcastTo::GetDstShape() const { return this->primitive_->value.AsBroadcastTo()->dst_shape; } void BroadcastTo::SetDstShape(const std::vector &dst_shape) { - this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape; + this->primitive_->value.AsBroadcastTo()->dst_shape = dst_shape; } #else std::vector BroadcastTo::GetDstShape() const { - auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape(); + auto fb_vector = this->primitive_->value_as_BroadcastTo()->dst_shape(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/broadcast_to.h b/mindspore/lite/src/ops/broadcast_to.h index 8a268f32928..2afaf9e1e87 100644 --- a/mindspore/lite/src/ops/broadcast_to.h +++ b/mindspore/lite/src/ops/broadcast_to.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_ -#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_ namespace mindspore { namespace lite { class BroadcastTo : public PrimitiveC { public: - explicit BroadcastTo(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDstShape() const; diff --git a/mindspore/lite/src/ops/caffe_p_relu.cc b/mindspore/lite/src/ops/caffe_p_relu.cc index 0623e8abc64..c7a74cc18ed 100644 --- a/mindspore/lite/src/ops/caffe_p_relu.cc +++ 
b/mindspore/lite/src/ops/caffe_p_relu.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -bool CaffePReLU::GetChannelShared() const { return this->primitive->value.AsCaffePReLU()->channelShared; } +bool CaffePReLU::GetChannelShared() const { return this->primitive_->value.AsCaffePReLU()->channelShared; } void CaffePReLU::SetChannelShared(bool channel_shared) { - this->primitive->value.AsCaffePReLU()->channelShared = channel_shared; + this->primitive_->value.AsCaffePReLU()->channelShared = channel_shared; } #else -bool CaffePReLU::GetChannelShared() const { return this->primitive->value_as_CaffePReLU()->channelShared(); } +bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); } void CaffePReLU::SetChannelShared(bool channel_shared) {} #endif diff --git a/mindspore/lite/src/ops/caffe_p_relu.h b/mindspore/lite/src/ops/caffe_p_relu.h index 76f44d52f51..92e7bf1feba 100644 --- a/mindspore/lite/src/ops/caffe_p_relu.h +++ b/mindspore/lite/src/ops/caffe_p_relu.h @@ -14,26 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" #include "src/ops/activation.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_ namespace mindspore { namespace lite { class CaffePReLU : public Activation { public: - explicit CaffePReLU(OriginPrimitive *primitive) : Activation(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {} +#endif + explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {} bool GetChannelShared() const; void SetChannelShared(bool channel_shared); diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 48d167f1a38..d7ba94ee00d 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -19,23 +19,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; } -int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; } +int Cast::GetSrcT() const { return this->primitive_->value.AsCast()->srcT; } +int Cast::GetDstT() const { return this->primitive_->value.AsCast()->dstT; } -void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; } -void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; } +void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; } +void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; } #else -int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); } -int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); } +int Cast::GetSrcT() 
const { return this->primitive_->value_as_Cast()->srcT(); } +int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); } void Cast::SetSrcT(int src_t) {} void Cast::SetDstT(int dst_t) {} #endif int Cast::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/cast.h b/mindspore/lite/src/ops/cast.h index cb58ee94fb7..45ceb800495 100644 --- a/mindspore/lite/src/ops/cast.h +++ b/mindspore/lite/src/ops/cast.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CAST_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CAST_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_ namespace mindspore { namespace lite { class Cast : public PrimitiveC { public: - explicit Cast(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; diff --git a/mindspore/lite/src/ops/ceil.h b/mindspore/lite/src/ops/ceil.h index d4591a46bc6..43dbc344ac5 100644 --- a/mindspore/lite/src/ops/ceil.h +++ b/mindspore/lite/src/ops/ceil.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Ceil : public ArithmeticSelf { public: - explicit Ceil(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/clip.cc b/mindspore/lite/src/ops/clip.cc index 6822872c161..656bd5c0f76 100644 --- a/mindspore/lite/src/ops/clip.cc +++ b/mindspore/lite/src/ops/clip.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float Clip::GetMax() const { return this->primitive->value.AsClip()->max; } -float Clip::GetMin() const { return this->primitive->value.AsClip()->min; } +float Clip::GetMax() const { return this->primitive_->value.AsClip()->max; } +float Clip::GetMin() const { return this->primitive_->value.AsClip()->min; } -void Clip::SetMax(float max) { this->primitive->value.AsClip()->max = max; } -void Clip::SetMin(float min) { this->primitive->value.AsClip()->min = min; } +void Clip::SetMax(float max) { this->primitive_->value.AsClip()->max = max; } +void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; } #else -float Clip::GetMax() const { return this->primitive->value_as_Clip()->max(); } -float Clip::GetMin() const { return this->primitive->value_as_Clip()->min(); } +float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); } +float Clip::GetMin() 
const { return this->primitive_->value_as_Clip()->min(); } void Clip::SetMax(float max) {} void Clip::SetMin(float min) {} diff --git a/mindspore/lite/src/ops/clip.h b/mindspore/lite/src/ops/clip.h index 2414ae75f02..2cda7fa5129 100644 --- a/mindspore/lite/src/ops/clip.h +++ b/mindspore/lite/src/ops/clip.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CLIP_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CLIP_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_ namespace mindspore { namespace lite { class Clip : public PrimitiveC { public: - explicit Clip(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetMax() const; float GetMin() const; diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 2e724b82d4f..5f087ea7a6b 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -21,16 +21,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; } -int Concat::GetN() const { return this->primitive->value.AsConcat()->n; } +int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; } +int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; } -void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; } -void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; } +void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; } +void Concat::SetN(int n) { 
this->primitive_->value.AsConcat()->n = n; } #else -int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); } -int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); } +int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } +int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } void Concat::SetAxis(int axis) {} void Concat::SetN(int n) {} @@ -40,7 +40,7 @@ namespace { constexpr int kConcatOutputNum = 1; } int Concat::InferShape(std::vector inputs_, std::vector outputs_) { - if (this->primitive == nullptr) { + if (this->primitive_ == nullptr) { MS_LOG(ERROR) << "primitive is nullptr!"; return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/concat.h b/mindspore/lite/src/ops/concat.h index 981731bedab..189b1387aac 100644 --- a/mindspore/lite/src/ops/concat.h +++ b/mindspore/lite/src/ops/concat.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_ namespace mindspore { namespace lite { class Concat : public PrimitiveC { public: - explicit Concat(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index ff4efebfde5..aa1dfb822a7 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ 
b/mindspore/lite/src/ops/constant_of_shape.cc @@ -25,13 +25,13 @@ constexpr int kShapeInputNum = 1; constexpr int kShapeOutputNum = 1; } // namespace #ifdef PRIMITIVE_WRITEABLE -float ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->value; } +float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; } -void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->value = value; } +void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; } #else -float ConstantOfShape::GetValue() const { return this->primitive->value_as_ConstantOfShape()->value(); } +float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); } void ConstantOfShape::SetValue(float value) {} #endif diff --git a/mindspore/lite/src/ops/constant_of_shape.h b/mindspore/lite/src/ops/constant_of_shape.h index f67521cbfc1..34e1791efbc 100644 --- a/mindspore/lite/src/ops/constant_of_shape.h +++ b/mindspore/lite/src/ops/constant_of_shape.h @@ -14,24 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif -#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_ -#define LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_ namespace mindspore { namespace lite { class ConstantOfShape : public PrimitiveC { public: - explicit ConstantOfShape(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; float GetValue() const; void SetValue(float value); diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index d0564a8f836..0a0baf98d7b 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -26,63 +26,63 @@ int Conv2D::PadDown() const { return this->pad_d_; } int Conv2D::PadLeft() const { return this->pad_l_; } int Conv2D::PadRight() const { return this->pad_r_; } #ifdef PRIMITIVE_WRITEABLE -int Conv2D::GetFormat() const { return this->primitive->value.AsConv2D()->format; } -int Conv2D::GetGroup() const { return this->primitive->value.AsConv2D()->group; } -int Conv2D::GetChannelIn() const { return this->primitive->value.AsConv2D()->channelIn; } -int Conv2D::GetChannelOut() const { return this->primitive->value.AsConv2D()->channelOut; } -int Conv2D::GetKernelW() const { return this->primitive->value.AsConv2D()->kernelW; } -int Conv2D::GetKernelH() const { return this->primitive->value.AsConv2D()->kernelH; } -int Conv2D::GetStrideW() const { return this->primitive->value.AsConv2D()->strideW; } -int Conv2D::GetStrideH() const { 
return this->primitive->value.AsConv2D()->strideH; } -int Conv2D::GetPadMode() const { return this->primitive->value.AsConv2D()->padMode; } -int Conv2D::GetPadUp() const { return this->primitive->value.AsConv2D()->padUp; } -int Conv2D::GetPadDown() const { return this->primitive->value.AsConv2D()->padDown; } -int Conv2D::GetPadLeft() const { return this->primitive->value.AsConv2D()->padLeft; } -int Conv2D::GetPadRight() const { return this->primitive->value.AsConv2D()->padRight; } -int Conv2D::GetDilateW() const { return this->primitive->value.AsConv2D()->dilateW; } -int Conv2D::GetDilateH() const { return this->primitive->value.AsConv2D()->dilateH; } -bool Conv2D::GetHasBias() const { return this->primitive->value.AsConv2D()->hasBias; } -int Conv2D::GetActivationType() const { return this->primitive->value.AsConv2D()->activationType; } +int Conv2D::GetFormat() const { return this->primitive_->value.AsConv2D()->format; } +int Conv2D::GetGroup() const { return this->primitive_->value.AsConv2D()->group; } +int Conv2D::GetChannelIn() const { return this->primitive_->value.AsConv2D()->channelIn; } +int Conv2D::GetChannelOut() const { return this->primitive_->value.AsConv2D()->channelOut; } +int Conv2D::GetKernelW() const { return this->primitive_->value.AsConv2D()->kernelW; } +int Conv2D::GetKernelH() const { return this->primitive_->value.AsConv2D()->kernelH; } +int Conv2D::GetStrideW() const { return this->primitive_->value.AsConv2D()->strideW; } +int Conv2D::GetStrideH() const { return this->primitive_->value.AsConv2D()->strideH; } +int Conv2D::GetPadMode() const { return this->primitive_->value.AsConv2D()->padMode; } +int Conv2D::GetPadUp() const { return this->primitive_->value.AsConv2D()->padUp; } +int Conv2D::GetPadDown() const { return this->primitive_->value.AsConv2D()->padDown; } +int Conv2D::GetPadLeft() const { return this->primitive_->value.AsConv2D()->padLeft; } +int Conv2D::GetPadRight() const { return this->primitive_->value.AsConv2D()->padRight; } +int 
Conv2D::GetDilateW() const { return this->primitive_->value.AsConv2D()->dilateW; } +int Conv2D::GetDilateH() const { return this->primitive_->value.AsConv2D()->dilateH; } +bool Conv2D::GetHasBias() const { return this->primitive_->value.AsConv2D()->hasBias; } +int Conv2D::GetActivationType() const { return this->primitive_->value.AsConv2D()->activationType; } -void Conv2D::SetFormat(int format) { this->primitive->value.AsConv2D()->format = (schema::Format)format; } -void Conv2D::SetGroup(int group) { this->primitive->value.AsConv2D()->group = group; } -void Conv2D::SetChannelIn(int channel_in) { this->primitive->value.AsConv2D()->channelIn = channel_in; } -void Conv2D::SetChannelOut(int channel_out) { this->primitive->value.AsConv2D()->channelOut = channel_out; } -void Conv2D::SetKernelW(int kernel_w) { this->primitive->value.AsConv2D()->kernelW = kernel_w; } -void Conv2D::SetKernelH(int kernel_h) { this->primitive->value.AsConv2D()->kernelH = kernel_h; } -void Conv2D::SetStrideW(int stride_w) { this->primitive->value.AsConv2D()->strideW = stride_w; } -void Conv2D::SetStrideH(int stride_h) { this->primitive->value.AsConv2D()->strideH = stride_h; } -void Conv2D::SetPadMode(int pad_mode) { this->primitive->value.AsConv2D()->padMode = (schema::PadMode)pad_mode; } -void Conv2D::SetPadUp(int pad_up) { this->primitive->value.AsConv2D()->padUp = pad_up; } -void Conv2D::SetPadDown(int pad_down) { this->primitive->value.AsConv2D()->padDown = pad_down; } -void Conv2D::SetPadLeft(int pad_left) { this->primitive->value.AsConv2D()->padLeft = pad_left; } -void Conv2D::SetPadRight(int pad_right) { this->primitive->value.AsConv2D()->padRight = pad_right; } -void Conv2D::SetDilateW(int dilate_w) { this->primitive->value.AsConv2D()->dilateW = dilate_w; } -void Conv2D::SetDilateH(int dilate_h) { this->primitive->value.AsConv2D()->dilateH = dilate_h; } -void Conv2D::SetHasBias(bool has_bias) { this->primitive->value.AsConv2D()->hasBias = has_bias; } +void Conv2D::SetFormat(int format) 
{ this->primitive_->value.AsConv2D()->format = (schema::Format)format; } +void Conv2D::SetGroup(int group) { this->primitive_->value.AsConv2D()->group = group; } +void Conv2D::SetChannelIn(int channel_in) { this->primitive_->value.AsConv2D()->channelIn = channel_in; } +void Conv2D::SetChannelOut(int channel_out) { this->primitive_->value.AsConv2D()->channelOut = channel_out; } +void Conv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2D()->kernelW = kernel_w; } +void Conv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2D()->kernelH = kernel_h; } +void Conv2D::SetStrideW(int stride_w) { this->primitive_->value.AsConv2D()->strideW = stride_w; } +void Conv2D::SetStrideH(int stride_h) { this->primitive_->value.AsConv2D()->strideH = stride_h; } +void Conv2D::SetPadMode(int pad_mode) { this->primitive_->value.AsConv2D()->padMode = (schema::PadMode)pad_mode; } +void Conv2D::SetPadUp(int pad_up) { this->primitive_->value.AsConv2D()->padUp = pad_up; } +void Conv2D::SetPadDown(int pad_down) { this->primitive_->value.AsConv2D()->padDown = pad_down; } +void Conv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2D()->padLeft = pad_left; } +void Conv2D::SetPadRight(int pad_right) { this->primitive_->value.AsConv2D()->padRight = pad_right; } +void Conv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2D()->dilateW = dilate_w; } +void Conv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2D()->dilateH = dilate_h; } +void Conv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2D()->hasBias = has_bias; } void Conv2D::SetActivationType(int activation_type) { - this->primitive->value.AsConv2D()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsConv2D()->activationType = (schema::ActivationType)activation_type; } #else -int Conv2D::GetFormat() const { return this->primitive->value_as_Conv2D()->format(); } -int Conv2D::GetGroup() const { return 
this->primitive->value_as_Conv2D()->group(); } -int Conv2D::GetChannelIn() const { return this->primitive->value_as_Conv2D()->channelIn(); } -int Conv2D::GetChannelOut() const { return this->primitive->value_as_Conv2D()->channelOut(); } -int Conv2D::GetKernelW() const { return this->primitive->value_as_Conv2D()->kernelW(); } -int Conv2D::GetKernelH() const { return this->primitive->value_as_Conv2D()->kernelH(); } -int Conv2D::GetStrideW() const { return this->primitive->value_as_Conv2D()->strideW(); } -int Conv2D::GetStrideH() const { return this->primitive->value_as_Conv2D()->strideH(); } -int Conv2D::GetPadMode() const { return this->primitive->value_as_Conv2D()->padMode(); } -int Conv2D::GetPadUp() const { return this->primitive->value_as_Conv2D()->padUp(); } -int Conv2D::GetPadDown() const { return this->primitive->value_as_Conv2D()->padDown(); } -int Conv2D::GetPadLeft() const { return this->primitive->value_as_Conv2D()->padLeft(); } -int Conv2D::GetPadRight() const { return this->primitive->value_as_Conv2D()->padRight(); } -int Conv2D::GetDilateW() const { return this->primitive->value_as_Conv2D()->dilateW(); } -int Conv2D::GetDilateH() const { return this->primitive->value_as_Conv2D()->dilateH(); } -bool Conv2D::GetHasBias() const { return this->primitive->value_as_Conv2D()->hasBias(); } -int Conv2D::GetActivationType() const { return this->primitive->value_as_Conv2D()->activationType(); } +int Conv2D::GetFormat() const { return this->primitive_->value_as_Conv2D()->format(); } +int Conv2D::GetGroup() const { return this->primitive_->value_as_Conv2D()->group(); } +int Conv2D::GetChannelIn() const { return this->primitive_->value_as_Conv2D()->channelIn(); } +int Conv2D::GetChannelOut() const { return this->primitive_->value_as_Conv2D()->channelOut(); } +int Conv2D::GetKernelW() const { return this->primitive_->value_as_Conv2D()->kernelW(); } +int Conv2D::GetKernelH() const { return this->primitive_->value_as_Conv2D()->kernelH(); } +int Conv2D::GetStrideW() 
const { return this->primitive_->value_as_Conv2D()->strideW(); } +int Conv2D::GetStrideH() const { return this->primitive_->value_as_Conv2D()->strideH(); } +int Conv2D::GetPadMode() const { return this->primitive_->value_as_Conv2D()->padMode(); } +int Conv2D::GetPadUp() const { return this->primitive_->value_as_Conv2D()->padUp(); } +int Conv2D::GetPadDown() const { return this->primitive_->value_as_Conv2D()->padDown(); } +int Conv2D::GetPadLeft() const { return this->primitive_->value_as_Conv2D()->padLeft(); } +int Conv2D::GetPadRight() const { return this->primitive_->value_as_Conv2D()->padRight(); } +int Conv2D::GetDilateW() const { return this->primitive_->value_as_Conv2D()->dilateW(); } +int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dilateH(); } +bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); } +int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); } void Conv2D::SetFormat(int format) {} void Conv2D::SetGroup(int group) {} @@ -103,7 +103,7 @@ void Conv2D::SetHasBias(bool has_bias) {} void Conv2D::SetActivationType(int activation_type) {} #endif void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); int kernel_w = GetKernelW(); int kernel_h = GetKernelH(); int stride_w = GetStrideW(); diff --git a/mindspore/lite/src/ops/conv2d.h b/mindspore/lite/src/ops/conv2d.h index 328fd49718b..2a1b718f3e8 100644 --- a/mindspore/lite/src/ops/conv2d.h +++ b/mindspore/lite/src/ops/conv2d.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_ namespace mindspore { namespace lite { class Conv2D : public PrimitiveC { public: - explicit Conv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int PadUp() const; int PadDown() const; diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index 49e29c87feb..1fcd9ced90d 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -19,72 +19,74 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Conv2DGradFilter::GetFormat() const { return this->primitive->value.AsConv2DGradFilter()->format; } -int Conv2DGradFilter::GetGroup() const { return this->primitive->value.AsConv2DGradFilter()->group; } -int Conv2DGradFilter::GetChannelIn() const { return this->primitive->value.AsConv2DGradFilter()->channelIn; } -int Conv2DGradFilter::GetChannelOut() const { return this->primitive->value.AsConv2DGradFilter()->channelOut; } -int Conv2DGradFilter::GetKernelW() const { return this->primitive->value.AsConv2DGradFilter()->kernelW; } -int Conv2DGradFilter::GetKernelH() const { return this->primitive->value.AsConv2DGradFilter()->kernelH; } -int Conv2DGradFilter::GetStrideW() const { return this->primitive->value.AsConv2DGradFilter()->strideW; } -int Conv2DGradFilter::GetStrideH() const { return 
this->primitive->value.AsConv2DGradFilter()->strideH; } -int Conv2DGradFilter::GetPadMode() const { return this->primitive->value.AsConv2DGradFilter()->padMode; } -int Conv2DGradFilter::GetPadUp() const { return this->primitive->value.AsConv2DGradFilter()->padUp; } -int Conv2DGradFilter::GetPadDown() const { return this->primitive->value.AsConv2DGradFilter()->padDown; } -int Conv2DGradFilter::GetPadLeft() const { return this->primitive->value.AsConv2DGradFilter()->padLeft; } -int Conv2DGradFilter::GetPadRight() const { return this->primitive->value.AsConv2DGradFilter()->padRight; } -int Conv2DGradFilter::GetDilateW() const { return this->primitive->value.AsConv2DGradFilter()->dilateW; } -int Conv2DGradFilter::GetDilateH() const { return this->primitive->value.AsConv2DGradFilter()->dilateH; } -bool Conv2DGradFilter::GetHasBias() const { return this->primitive->value.AsConv2DGradFilter()->hasBias; } -int Conv2DGradFilter::GetActivationType() const { return this->primitive->value.AsConv2DGradFilter()->activationType; } +int Conv2DGradFilter::GetFormat() const { return this->primitive_->value.AsConv2DGradFilter()->format; } +int Conv2DGradFilter::GetGroup() const { return this->primitive_->value.AsConv2DGradFilter()->group; } +int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value.AsConv2DGradFilter()->channelIn; } +int Conv2DGradFilter::GetChannelOut() const { return this->primitive_->value.AsConv2DGradFilter()->channelOut; } +int Conv2DGradFilter::GetKernelW() const { return this->primitive_->value.AsConv2DGradFilter()->kernelW; } +int Conv2DGradFilter::GetKernelH() const { return this->primitive_->value.AsConv2DGradFilter()->kernelH; } +int Conv2DGradFilter::GetStrideW() const { return this->primitive_->value.AsConv2DGradFilter()->strideW; } +int Conv2DGradFilter::GetStrideH() const { return this->primitive_->value.AsConv2DGradFilter()->strideH; } +int Conv2DGradFilter::GetPadMode() const { return 
this->primitive_->value.AsConv2DGradFilter()->padMode; } +int Conv2DGradFilter::GetPadUp() const { return this->primitive_->value.AsConv2DGradFilter()->padUp; } +int Conv2DGradFilter::GetPadDown() const { return this->primitive_->value.AsConv2DGradFilter()->padDown; } +int Conv2DGradFilter::GetPadLeft() const { return this->primitive_->value.AsConv2DGradFilter()->padLeft; } +int Conv2DGradFilter::GetPadRight() const { return this->primitive_->value.AsConv2DGradFilter()->padRight; } +int Conv2DGradFilter::GetDilateW() const { return this->primitive_->value.AsConv2DGradFilter()->dilateW; } +int Conv2DGradFilter::GetDilateH() const { return this->primitive_->value.AsConv2DGradFilter()->dilateH; } +bool Conv2DGradFilter::GetHasBias() const { return this->primitive_->value.AsConv2DGradFilter()->hasBias; } +int Conv2DGradFilter::GetActivationType() const { return this->primitive_->value.AsConv2DGradFilter()->activationType; } void Conv2DGradFilter::SetFormat(int format) { - this->primitive->value.AsConv2DGradFilter()->format = (schema::Format)format; + this->primitive_->value.AsConv2DGradFilter()->format = (schema::Format)format; } -void Conv2DGradFilter::SetGroup(int group) { this->primitive->value.AsConv2DGradFilter()->group = group; } +void Conv2DGradFilter::SetGroup(int group) { this->primitive_->value.AsConv2DGradFilter()->group = group; } void Conv2DGradFilter::SetChannelIn(int channel_in) { - this->primitive->value.AsConv2DGradFilter()->channelIn = channel_in; + this->primitive_->value.AsConv2DGradFilter()->channelIn = channel_in; } void Conv2DGradFilter::SetChannelOut(int channel_out) { - this->primitive->value.AsConv2DGradFilter()->channelOut = channel_out; + this->primitive_->value.AsConv2DGradFilter()->channelOut = channel_out; } -void Conv2DGradFilter::SetKernelW(int kernel_w) { this->primitive->value.AsConv2DGradFilter()->kernelW = kernel_w; } -void Conv2DGradFilter::SetKernelH(int kernel_h) { this->primitive->value.AsConv2DGradFilter()->kernelH = kernel_h; 
} -void Conv2DGradFilter::SetStrideW(int stride_w) { this->primitive->value.AsConv2DGradFilter()->strideW = stride_w; } -void Conv2DGradFilter::SetStrideH(int stride_h) { this->primitive->value.AsConv2DGradFilter()->strideH = stride_h; } +void Conv2DGradFilter::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2DGradFilter()->kernelW = kernel_w; } +void Conv2DGradFilter::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2DGradFilter()->kernelH = kernel_h; } +void Conv2DGradFilter::SetStrideW(int stride_w) { this->primitive_->value.AsConv2DGradFilter()->strideW = stride_w; } +void Conv2DGradFilter::SetStrideH(int stride_h) { this->primitive_->value.AsConv2DGradFilter()->strideH = stride_h; } void Conv2DGradFilter::SetPadMode(int pad_mode) { - this->primitive->value.AsConv2DGradFilter()->padMode = (schema::PadMode)pad_mode; + this->primitive_->value.AsConv2DGradFilter()->padMode = (schema::PadMode)pad_mode; } -void Conv2DGradFilter::SetPadUp(int pad_up) { this->primitive->value.AsConv2DGradFilter()->padUp = pad_up; } -void Conv2DGradFilter::SetPadDown(int pad_down) { this->primitive->value.AsConv2DGradFilter()->padDown = pad_down; } -void Conv2DGradFilter::SetPadLeft(int pad_left) { this->primitive->value.AsConv2DGradFilter()->padLeft = pad_left; } -void Conv2DGradFilter::SetPadRight(int pad_right) { this->primitive->value.AsConv2DGradFilter()->padRight = pad_right; } -void Conv2DGradFilter::SetDilateW(int dilate_w) { this->primitive->value.AsConv2DGradFilter()->dilateW = dilate_w; } -void Conv2DGradFilter::SetDilateH(int dilate_h) { this->primitive->value.AsConv2DGradFilter()->dilateH = dilate_h; } -void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive->value.AsConv2DGradFilter()->hasBias = has_bias; } +void Conv2DGradFilter::SetPadUp(int pad_up) { this->primitive_->value.AsConv2DGradFilter()->padUp = pad_up; } +void Conv2DGradFilter::SetPadDown(int pad_down) { this->primitive_->value.AsConv2DGradFilter()->padDown = pad_down; } +void 
Conv2DGradFilter::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2DGradFilter()->padLeft = pad_left; } +void Conv2DGradFilter::SetPadRight(int pad_right) { + this->primitive_->value.AsConv2DGradFilter()->padRight = pad_right; +} +void Conv2DGradFilter::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2DGradFilter()->dilateW = dilate_w; } +void Conv2DGradFilter::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2DGradFilter()->dilateH = dilate_h; } +void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2DGradFilter()->hasBias = has_bias; } void Conv2DGradFilter::SetActivationType(int activation_type) { - this->primitive->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type; } #else -int Conv2DGradFilter::GetFormat() const { return this->primitive->value_as_Conv2DGradFilter()->format(); } -int Conv2DGradFilter::GetGroup() const { return this->primitive->value_as_Conv2DGradFilter()->group(); } -int Conv2DGradFilter::GetChannelIn() const { return this->primitive->value_as_Conv2DGradFilter()->channelIn(); } -int Conv2DGradFilter::GetChannelOut() const { return this->primitive->value_as_Conv2DGradFilter()->channelOut(); } -int Conv2DGradFilter::GetKernelW() const { return this->primitive->value_as_Conv2DGradFilter()->kernelW(); } -int Conv2DGradFilter::GetKernelH() const { return this->primitive->value_as_Conv2DGradFilter()->kernelH(); } -int Conv2DGradFilter::GetStrideW() const { return this->primitive->value_as_Conv2DGradFilter()->strideW(); } -int Conv2DGradFilter::GetStrideH() const { return this->primitive->value_as_Conv2DGradFilter()->strideH(); } -int Conv2DGradFilter::GetPadMode() const { return this->primitive->value_as_Conv2DGradFilter()->padMode(); } -int Conv2DGradFilter::GetPadUp() const { return this->primitive->value_as_Conv2DGradFilter()->padUp(); } -int 
Conv2DGradFilter::GetPadDown() const { return this->primitive->value_as_Conv2DGradFilter()->padDown(); } -int Conv2DGradFilter::GetPadLeft() const { return this->primitive->value_as_Conv2DGradFilter()->padLeft(); } -int Conv2DGradFilter::GetPadRight() const { return this->primitive->value_as_Conv2DGradFilter()->padRight(); } -int Conv2DGradFilter::GetDilateW() const { return this->primitive->value_as_Conv2DGradFilter()->dilateW(); } -int Conv2DGradFilter::GetDilateH() const { return this->primitive->value_as_Conv2DGradFilter()->dilateH(); } -bool Conv2DGradFilter::GetHasBias() const { return this->primitive->value_as_Conv2DGradFilter()->hasBias(); } +int Conv2DGradFilter::GetFormat() const { return this->primitive_->value_as_Conv2DGradFilter()->format(); } +int Conv2DGradFilter::GetGroup() const { return this->primitive_->value_as_Conv2DGradFilter()->group(); } +int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradFilter()->channelIn(); } +int Conv2DGradFilter::GetChannelOut() const { return this->primitive_->value_as_Conv2DGradFilter()->channelOut(); } +int Conv2DGradFilter::GetKernelW() const { return this->primitive_->value_as_Conv2DGradFilter()->kernelW(); } +int Conv2DGradFilter::GetKernelH() const { return this->primitive_->value_as_Conv2DGradFilter()->kernelH(); } +int Conv2DGradFilter::GetStrideW() const { return this->primitive_->value_as_Conv2DGradFilter()->strideW(); } +int Conv2DGradFilter::GetStrideH() const { return this->primitive_->value_as_Conv2DGradFilter()->strideH(); } +int Conv2DGradFilter::GetPadMode() const { return this->primitive_->value_as_Conv2DGradFilter()->padMode(); } +int Conv2DGradFilter::GetPadUp() const { return this->primitive_->value_as_Conv2DGradFilter()->padUp(); } +int Conv2DGradFilter::GetPadDown() const { return this->primitive_->value_as_Conv2DGradFilter()->padDown(); } +int Conv2DGradFilter::GetPadLeft() const { return this->primitive_->value_as_Conv2DGradFilter()->padLeft(); } +int 
Conv2DGradFilter::GetPadRight() const { return this->primitive_->value_as_Conv2DGradFilter()->padRight(); } +int Conv2DGradFilter::GetDilateW() const { return this->primitive_->value_as_Conv2DGradFilter()->dilateW(); } +int Conv2DGradFilter::GetDilateH() const { return this->primitive_->value_as_Conv2DGradFilter()->dilateH(); } +bool Conv2DGradFilter::GetHasBias() const { return this->primitive_->value_as_Conv2DGradFilter()->hasBias(); } int Conv2DGradFilter::GetActivationType() const { - return this->primitive->value_as_Conv2DGradFilter()->activationType(); + return this->primitive_->value_as_Conv2DGradFilter()->activationType(); } void Conv2DGradFilter::SetFormat(int format) {} diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.h b/mindspore/lite/src/ops/conv2d_grad_filter.h index 7094983ee02..7ed1d696611 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.h +++ b/mindspore/lite/src/ops/conv2d_grad_filter.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_ namespace mindspore { namespace lite { class Conv2DGradFilter : public PrimitiveC { public: - explicit Conv2DGradFilter(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetFormat() const; int GetGroup() const; diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index 9892f98a63d..28a66d2e3c1 
100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -19,71 +19,73 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Conv2DGradInput::GetFormat() const { return this->primitive->value.AsConv2DGradInput()->format; } -int Conv2DGradInput::GetGroup() const { return this->primitive->value.AsConv2DGradInput()->group; } -int Conv2DGradInput::GetChannelIn() const { return this->primitive->value.AsConv2DGradInput()->channelIn; } -int Conv2DGradInput::GetChannelOut() const { return this->primitive->value.AsConv2DGradInput()->channelOut; } -int Conv2DGradInput::GetKernelW() const { return this->primitive->value.AsConv2DGradInput()->kernelW; } -int Conv2DGradInput::GetKernelH() const { return this->primitive->value.AsConv2DGradInput()->kernelH; } -int Conv2DGradInput::GetStrideW() const { return this->primitive->value.AsConv2DGradInput()->strideW; } -int Conv2DGradInput::GetStrideH() const { return this->primitive->value.AsConv2DGradInput()->strideH; } -int Conv2DGradInput::GetPadMode() const { return this->primitive->value.AsConv2DGradInput()->padMode; } -int Conv2DGradInput::GetPadUp() const { return this->primitive->value.AsConv2DGradInput()->padUp; } -int Conv2DGradInput::GetPadDown() const { return this->primitive->value.AsConv2DGradInput()->padDown; } -int Conv2DGradInput::GetPadLeft() const { return this->primitive->value.AsConv2DGradInput()->padLeft; } -int Conv2DGradInput::GetPadRight() const { return this->primitive->value.AsConv2DGradInput()->padRight; } -int Conv2DGradInput::GetDilateW() const { return this->primitive->value.AsConv2DGradInput()->dilateW; } -int Conv2DGradInput::GetDilateH() const { return this->primitive->value.AsConv2DGradInput()->dilateH; } -bool Conv2DGradInput::GetHasBias() const { return this->primitive->value.AsConv2DGradInput()->hasBias; } -int Conv2DGradInput::GetActivationType() const { return this->primitive->value.AsConv2DGradInput()->activationType; } +int 
Conv2DGradInput::GetFormat() const { return this->primitive_->value.AsConv2DGradInput()->format; } +int Conv2DGradInput::GetGroup() const { return this->primitive_->value.AsConv2DGradInput()->group; } +int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsConv2DGradInput()->channelIn; } +int Conv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsConv2DGradInput()->channelOut; } +int Conv2DGradInput::GetKernelW() const { return this->primitive_->value.AsConv2DGradInput()->kernelW; } +int Conv2DGradInput::GetKernelH() const { return this->primitive_->value.AsConv2DGradInput()->kernelH; } +int Conv2DGradInput::GetStrideW() const { return this->primitive_->value.AsConv2DGradInput()->strideW; } +int Conv2DGradInput::GetStrideH() const { return this->primitive_->value.AsConv2DGradInput()->strideH; } +int Conv2DGradInput::GetPadMode() const { return this->primitive_->value.AsConv2DGradInput()->padMode; } +int Conv2DGradInput::GetPadUp() const { return this->primitive_->value.AsConv2DGradInput()->padUp; } +int Conv2DGradInput::GetPadDown() const { return this->primitive_->value.AsConv2DGradInput()->padDown; } +int Conv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsConv2DGradInput()->padLeft; } +int Conv2DGradInput::GetPadRight() const { return this->primitive_->value.AsConv2DGradInput()->padRight; } +int Conv2DGradInput::GetDilateW() const { return this->primitive_->value.AsConv2DGradInput()->dilateW; } +int Conv2DGradInput::GetDilateH() const { return this->primitive_->value.AsConv2DGradInput()->dilateH; } +bool Conv2DGradInput::GetHasBias() const { return this->primitive_->value.AsConv2DGradInput()->hasBias; } +int Conv2DGradInput::GetActivationType() const { return this->primitive_->value.AsConv2DGradInput()->activationType; } void Conv2DGradInput::SetFormat(int format) { - this->primitive->value.AsConv2DGradInput()->format = (schema::Format)format; + this->primitive_->value.AsConv2DGradInput()->format = 
(schema::Format)format; } -void Conv2DGradInput::SetGroup(int group) { this->primitive->value.AsConv2DGradInput()->group = group; } +void Conv2DGradInput::SetGroup(int group) { this->primitive_->value.AsConv2DGradInput()->group = group; } void Conv2DGradInput::SetChannelIn(int channel_in) { - this->primitive->value.AsConv2DGradInput()->channelIn = channel_in; + this->primitive_->value.AsConv2DGradInput()->channelIn = channel_in; } void Conv2DGradInput::SetChannelOut(int channel_out) { - this->primitive->value.AsConv2DGradInput()->channelOut = channel_out; + this->primitive_->value.AsConv2DGradInput()->channelOut = channel_out; } -void Conv2DGradInput::SetKernelW(int kernel_w) { this->primitive->value.AsConv2DGradInput()->kernelW = kernel_w; } -void Conv2DGradInput::SetKernelH(int kernel_h) { this->primitive->value.AsConv2DGradInput()->kernelH = kernel_h; } -void Conv2DGradInput::SetStrideW(int stride_w) { this->primitive->value.AsConv2DGradInput()->strideW = stride_w; } -void Conv2DGradInput::SetStrideH(int stride_h) { this->primitive->value.AsConv2DGradInput()->strideH = stride_h; } +void Conv2DGradInput::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2DGradInput()->kernelW = kernel_w; } +void Conv2DGradInput::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2DGradInput()->kernelH = kernel_h; } +void Conv2DGradInput::SetStrideW(int stride_w) { this->primitive_->value.AsConv2DGradInput()->strideW = stride_w; } +void Conv2DGradInput::SetStrideH(int stride_h) { this->primitive_->value.AsConv2DGradInput()->strideH = stride_h; } void Conv2DGradInput::SetPadMode(int pad_mode) { - this->primitive->value.AsConv2DGradInput()->padMode = (schema::PadMode)pad_mode; + this->primitive_->value.AsConv2DGradInput()->padMode = (schema::PadMode)pad_mode; } -void Conv2DGradInput::SetPadUp(int pad_up) { this->primitive->value.AsConv2DGradInput()->padUp = pad_up; } -void Conv2DGradInput::SetPadDown(int pad_down) { this->primitive->value.AsConv2DGradInput()->padDown = 
pad_down; } -void Conv2DGradInput::SetPadLeft(int pad_left) { this->primitive->value.AsConv2DGradInput()->padLeft = pad_left; } -void Conv2DGradInput::SetPadRight(int pad_right) { this->primitive->value.AsConv2DGradInput()->padRight = pad_right; } -void Conv2DGradInput::SetDilateW(int dilate_w) { this->primitive->value.AsConv2DGradInput()->dilateW = dilate_w; } -void Conv2DGradInput::SetDilateH(int dilate_h) { this->primitive->value.AsConv2DGradInput()->dilateH = dilate_h; } -void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive->value.AsConv2DGradInput()->hasBias = has_bias; } +void Conv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsConv2DGradInput()->padUp = pad_up; } +void Conv2DGradInput::SetPadDown(int pad_down) { this->primitive_->value.AsConv2DGradInput()->padDown = pad_down; } +void Conv2DGradInput::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2DGradInput()->padLeft = pad_left; } +void Conv2DGradInput::SetPadRight(int pad_right) { this->primitive_->value.AsConv2DGradInput()->padRight = pad_right; } +void Conv2DGradInput::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2DGradInput()->dilateW = dilate_w; } +void Conv2DGradInput::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2DGradInput()->dilateH = dilate_h; } +void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2DGradInput()->hasBias = has_bias; } void Conv2DGradInput::SetActivationType(int activation_type) { - this->primitive->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type; } #else -int Conv2DGradInput::GetFormat() const { return this->primitive->value_as_Conv2DGradInput()->format(); } -int Conv2DGradInput::GetGroup() const { return this->primitive->value_as_Conv2DGradInput()->group(); } -int Conv2DGradInput::GetChannelIn() const { return 
this->primitive->value_as_Conv2DGradInput()->channelIn(); } -int Conv2DGradInput::GetChannelOut() const { return this->primitive->value_as_Conv2DGradInput()->channelOut(); } -int Conv2DGradInput::GetKernelW() const { return this->primitive->value_as_Conv2DGradInput()->kernelW(); } -int Conv2DGradInput::GetKernelH() const { return this->primitive->value_as_Conv2DGradInput()->kernelH(); } -int Conv2DGradInput::GetStrideW() const { return this->primitive->value_as_Conv2DGradInput()->strideW(); } -int Conv2DGradInput::GetStrideH() const { return this->primitive->value_as_Conv2DGradInput()->strideH(); } -int Conv2DGradInput::GetPadMode() const { return this->primitive->value_as_Conv2DGradInput()->padMode(); } -int Conv2DGradInput::GetPadUp() const { return this->primitive->value_as_Conv2DGradInput()->padUp(); } -int Conv2DGradInput::GetPadDown() const { return this->primitive->value_as_Conv2DGradInput()->padDown(); } -int Conv2DGradInput::GetPadLeft() const { return this->primitive->value_as_Conv2DGradInput()->padLeft(); } -int Conv2DGradInput::GetPadRight() const { return this->primitive->value_as_Conv2DGradInput()->padRight(); } -int Conv2DGradInput::GetDilateW() const { return this->primitive->value_as_Conv2DGradInput()->dilateW(); } -int Conv2DGradInput::GetDilateH() const { return this->primitive->value_as_Conv2DGradInput()->dilateH(); } -bool Conv2DGradInput::GetHasBias() const { return this->primitive->value_as_Conv2DGradInput()->hasBias(); } -int Conv2DGradInput::GetActivationType() const { return this->primitive->value_as_Conv2DGradInput()->activationType(); } +int Conv2DGradInput::GetFormat() const { return this->primitive_->value_as_Conv2DGradInput()->format(); } +int Conv2DGradInput::GetGroup() const { return this->primitive_->value_as_Conv2DGradInput()->group(); } +int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradInput()->channelIn(); } +int Conv2DGradInput::GetChannelOut() const { return 
this->primitive_->value_as_Conv2DGradInput()->channelOut(); } +int Conv2DGradInput::GetKernelW() const { return this->primitive_->value_as_Conv2DGradInput()->kernelW(); } +int Conv2DGradInput::GetKernelH() const { return this->primitive_->value_as_Conv2DGradInput()->kernelH(); } +int Conv2DGradInput::GetStrideW() const { return this->primitive_->value_as_Conv2DGradInput()->strideW(); } +int Conv2DGradInput::GetStrideH() const { return this->primitive_->value_as_Conv2DGradInput()->strideH(); } +int Conv2DGradInput::GetPadMode() const { return this->primitive_->value_as_Conv2DGradInput()->padMode(); } +int Conv2DGradInput::GetPadUp() const { return this->primitive_->value_as_Conv2DGradInput()->padUp(); } +int Conv2DGradInput::GetPadDown() const { return this->primitive_->value_as_Conv2DGradInput()->padDown(); } +int Conv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_Conv2DGradInput()->padLeft(); } +int Conv2DGradInput::GetPadRight() const { return this->primitive_->value_as_Conv2DGradInput()->padRight(); } +int Conv2DGradInput::GetDilateW() const { return this->primitive_->value_as_Conv2DGradInput()->dilateW(); } +int Conv2DGradInput::GetDilateH() const { return this->primitive_->value_as_Conv2DGradInput()->dilateH(); } +bool Conv2DGradInput::GetHasBias() const { return this->primitive_->value_as_Conv2DGradInput()->hasBias(); } +int Conv2DGradInput::GetActivationType() const { + return this->primitive_->value_as_Conv2DGradInput()->activationType(); +} void Conv2DGradInput::SetFormat(int format) {} void Conv2DGradInput::SetGroup(int group) {} diff --git a/mindspore/lite/src/ops/conv2d_grad_input.h b/mindspore/lite/src/ops/conv2d_grad_input.h index 4c06d0b858a..7c71d6d2095 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.h +++ b/mindspore/lite/src/ops/conv2d_grad_input.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_ namespace mindspore { namespace lite { class Conv2DGradInput : public PrimitiveC { public: - explicit Conv2DGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetFormat() const; int GetGroup() const; diff --git a/mindspore/lite/src/ops/cos.h b/mindspore/lite/src/ops/cos.h index 25f199c104f..1cc39df284a 100644 --- a/mindspore/lite/src/ops/cos.h +++ b/mindspore/lite/src/ops/cos.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_ +#define LITE_MINDSPORE_LITE_C_OPS_COS_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_COS_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Cos : public ArithmeticSelf { public: - explicit Cos(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index 7e2c32167c8..1b74a9bb825 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -19,17 +19,17 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -long Crop::GetAxis() const { return this->primitive->value.AsCrop()->axis; } -std::vector Crop::GetOffsets() const { return this->primitive->value.AsCrop()->offsets; } +long Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; } +std::vector Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; } -void Crop::SetAxis(long axis) { this->primitive->value.AsCrop()->axis = axis; } -void Crop::SetOffsets(const std::vector &offsets) { this->primitive->value.AsCrop()->offsets = offsets; } +void Crop::SetAxis(long axis) { this->primitive_->value.AsCrop()->axis = axis; } +void Crop::SetOffsets(const std::vector &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; } #else -long Crop::GetAxis() const { return this->primitive->value_as_Crop()->axis(); } +long Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); } std::vector 
Crop::GetOffsets() const { - auto fb_vector = this->primitive->value_as_Crop()->offsets(); + auto fb_vector = this->primitive_->value_as_Crop()->offsets(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/crop.h b/mindspore/lite/src/ops/crop.h index 87a6fb2dbdc..c47402c043b 100644 --- a/mindspore/lite/src/ops/crop.h +++ b/mindspore/lite/src/ops/crop.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_CROP_H_ +#define LITE_MINDSPORE_LITE_C_OPS_CROP_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_CROP_H_ -#define LITE_MINDSPORE_LITE_C_OPS_CROP_H_ namespace mindspore { namespace lite { class Crop : public PrimitiveC { public: - explicit Crop(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; long GetAxis() const; diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 91a5f31a869..1e930290741 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -19,63 +19,63 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int DeConv2D::GetFormat() const { return this->primitive->value.AsDeConv2D()->format; } -int DeConv2D::GetGroup() const { return this->primitive->value.AsDeConv2D()->group; } -int DeConv2D::GetChannelIn() const { return this->primitive->value.AsDeConv2D()->channelIn; } -int DeConv2D::GetChannelOut() const { return this->primitive->value.AsDeConv2D()->channelOut; } -int DeConv2D::GetKernelW() const { return 
this->primitive->value.AsDeConv2D()->kernelW; } -int DeConv2D::GetKernelH() const { return this->primitive->value.AsDeConv2D()->kernelH; } -int DeConv2D::GetStrideW() const { return this->primitive->value.AsDeConv2D()->strideW; } -int DeConv2D::GetStrideH() const { return this->primitive->value.AsDeConv2D()->strideH; } -int DeConv2D::GetPadMode() const { return this->primitive->value.AsDeConv2D()->padMode; } -int DeConv2D::GetPadUp() const { return this->primitive->value.AsDeConv2D()->padUp; } -int DeConv2D::GetPadDown() const { return this->primitive->value.AsDeConv2D()->padDown; } -int DeConv2D::GetPadLeft() const { return this->primitive->value.AsDeConv2D()->padLeft; } -int DeConv2D::GetPadRight() const { return this->primitive->value.AsDeConv2D()->padRight; } -int DeConv2D::GetDilateW() const { return this->primitive->value.AsDeConv2D()->dilateW; } -int DeConv2D::GetDilateH() const { return this->primitive->value.AsDeConv2D()->dilateH; } -bool DeConv2D::GetHasBias() const { return this->primitive->value.AsDeConv2D()->hasBias; } -int DeConv2D::GetActivationType() const { return this->primitive->value.AsDeConv2D()->activationType; } +int DeConv2D::GetFormat() const { return this->primitive_->value.AsDeConv2D()->format; } +int DeConv2D::GetGroup() const { return this->primitive_->value.AsDeConv2D()->group; } +int DeConv2D::GetChannelIn() const { return this->primitive_->value.AsDeConv2D()->channelIn; } +int DeConv2D::GetChannelOut() const { return this->primitive_->value.AsDeConv2D()->channelOut; } +int DeConv2D::GetKernelW() const { return this->primitive_->value.AsDeConv2D()->kernelW; } +int DeConv2D::GetKernelH() const { return this->primitive_->value.AsDeConv2D()->kernelH; } +int DeConv2D::GetStrideW() const { return this->primitive_->value.AsDeConv2D()->strideW; } +int DeConv2D::GetStrideH() const { return this->primitive_->value.AsDeConv2D()->strideH; } +int DeConv2D::GetPadMode() const { return this->primitive_->value.AsDeConv2D()->padMode; } +int 
DeConv2D::GetPadUp() const { return this->primitive_->value.AsDeConv2D()->padUp; } +int DeConv2D::GetPadDown() const { return this->primitive_->value.AsDeConv2D()->padDown; } +int DeConv2D::GetPadLeft() const { return this->primitive_->value.AsDeConv2D()->padLeft; } +int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()->padRight; } +int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; } +int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; } +bool DeConv2D::GetHasBias() const { return this->primitive_->value.AsDeConv2D()->hasBias; } +int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; } -void DeConv2D::SetFormat(int format) { this->primitive->value.AsDeConv2D()->format = (schema::Format)format; } -void DeConv2D::SetGroup(int group) { this->primitive->value.AsDeConv2D()->group = group; } -void DeConv2D::SetChannelIn(int channel_in) { this->primitive->value.AsDeConv2D()->channelIn = channel_in; } -void DeConv2D::SetChannelOut(int channel_out) { this->primitive->value.AsDeConv2D()->channelOut = channel_out; } -void DeConv2D::SetKernelW(int kernel_w) { this->primitive->value.AsDeConv2D()->kernelW = kernel_w; } -void DeConv2D::SetKernelH(int kernel_h) { this->primitive->value.AsDeConv2D()->kernelH = kernel_h; } -void DeConv2D::SetStrideW(int stride_w) { this->primitive->value.AsDeConv2D()->strideW = stride_w; } -void DeConv2D::SetStrideH(int stride_h) { this->primitive->value.AsDeConv2D()->strideH = stride_h; } -void DeConv2D::SetPadMode(int pad_mode) { this->primitive->value.AsDeConv2D()->padMode = (schema::PadMode)pad_mode; } -void DeConv2D::SetPadUp(int pad_up) { this->primitive->value.AsDeConv2D()->padUp = pad_up; } -void DeConv2D::SetPadDown(int pad_down) { this->primitive->value.AsDeConv2D()->padDown = pad_down; } -void DeConv2D::SetPadLeft(int pad_left) { this->primitive->value.AsDeConv2D()->padLeft = pad_left; } 
-void DeConv2D::SetPadRight(int pad_right) { this->primitive->value.AsDeConv2D()->padRight = pad_right; } -void DeConv2D::SetDilateW(int dilate_w) { this->primitive->value.AsDeConv2D()->dilateW = dilate_w; } -void DeConv2D::SetDilateH(int dilate_h) { this->primitive->value.AsDeConv2D()->dilateH = dilate_h; } -void DeConv2D::SetHasBias(bool has_bias) { this->primitive->value.AsDeConv2D()->hasBias = has_bias; } +void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; } +void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; } +void DeConv2D::SetChannelIn(int channel_in) { this->primitive_->value.AsDeConv2D()->channelIn = channel_in; } +void DeConv2D::SetChannelOut(int channel_out) { this->primitive_->value.AsDeConv2D()->channelOut = channel_out; } +void DeConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDeConv2D()->kernelW = kernel_w; } +void DeConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDeConv2D()->kernelH = kernel_h; } +void DeConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDeConv2D()->strideW = stride_w; } +void DeConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDeConv2D()->strideH = stride_h; } +void DeConv2D::SetPadMode(int pad_mode) { this->primitive_->value.AsDeConv2D()->padMode = (schema::PadMode)pad_mode; } +void DeConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDeConv2D()->padUp = pad_up; } +void DeConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDeConv2D()->padDown = pad_down; } +void DeConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDeConv2D()->padLeft = pad_left; } +void DeConv2D::SetPadRight(int pad_right) { this->primitive_->value.AsDeConv2D()->padRight = pad_right; } +void DeConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDeConv2D()->dilateW = dilate_w; } +void DeConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDeConv2D()->dilateH = dilate_h; } +void 
DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()->hasBias = has_bias; } void DeConv2D::SetActivationType(int activation_type) { - this->primitive->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type; } #else -int DeConv2D::GetFormat() const { return this->primitive->value_as_DeConv2D()->format(); } -int DeConv2D::GetGroup() const { return this->primitive->value_as_DeConv2D()->group(); } -int DeConv2D::GetChannelIn() const { return this->primitive->value_as_DeConv2D()->channelIn(); } -int DeConv2D::GetChannelOut() const { return this->primitive->value_as_DeConv2D()->channelOut(); } -int DeConv2D::GetKernelW() const { return this->primitive->value_as_DeConv2D()->kernelW(); } -int DeConv2D::GetKernelH() const { return this->primitive->value_as_DeConv2D()->kernelH(); } -int DeConv2D::GetStrideW() const { return this->primitive->value_as_DeConv2D()->strideW(); } -int DeConv2D::GetStrideH() const { return this->primitive->value_as_DeConv2D()->strideH(); } -int DeConv2D::GetPadMode() const { return this->primitive->value_as_DeConv2D()->padMode(); } -int DeConv2D::GetPadUp() const { return this->primitive->value_as_DeConv2D()->padUp(); } -int DeConv2D::GetPadDown() const { return this->primitive->value_as_DeConv2D()->padDown(); } -int DeConv2D::GetPadLeft() const { return this->primitive->value_as_DeConv2D()->padLeft(); } -int DeConv2D::GetPadRight() const { return this->primitive->value_as_DeConv2D()->padRight(); } -int DeConv2D::GetDilateW() const { return this->primitive->value_as_DeConv2D()->dilateW(); } -int DeConv2D::GetDilateH() const { return this->primitive->value_as_DeConv2D()->dilateH(); } -bool DeConv2D::GetHasBias() const { return this->primitive->value_as_DeConv2D()->hasBias(); } -int DeConv2D::GetActivationType() const { return this->primitive->value_as_DeConv2D()->activationType(); } +int DeConv2D::GetFormat() 
const { return this->primitive_->value_as_DeConv2D()->format(); } +int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); } +int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); } +int DeConv2D::GetChannelOut() const { return this->primitive_->value_as_DeConv2D()->channelOut(); } +int DeConv2D::GetKernelW() const { return this->primitive_->value_as_DeConv2D()->kernelW(); } +int DeConv2D::GetKernelH() const { return this->primitive_->value_as_DeConv2D()->kernelH(); } +int DeConv2D::GetStrideW() const { return this->primitive_->value_as_DeConv2D()->strideW(); } +int DeConv2D::GetStrideH() const { return this->primitive_->value_as_DeConv2D()->strideH(); } +int DeConv2D::GetPadMode() const { return this->primitive_->value_as_DeConv2D()->padMode(); } +int DeConv2D::GetPadUp() const { return this->primitive_->value_as_DeConv2D()->padUp(); } +int DeConv2D::GetPadDown() const { return this->primitive_->value_as_DeConv2D()->padDown(); } +int DeConv2D::GetPadLeft() const { return this->primitive_->value_as_DeConv2D()->padLeft(); } +int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()->padRight(); } +int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); } +int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); } +bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); } +int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); } void DeConv2D::SetFormat(int format) {} void DeConv2D::SetGroup(int group) {} @@ -96,7 +96,7 @@ void DeConv2D::SetHasBias(bool has_bias) {} void DeConv2D::SetActivationType(int activation_type) {} #endif int DeConv2D::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); 
MS_ASSERT(input != nullptr); auto weight = inputs_.at(1); diff --git a/mindspore/lite/src/ops/deconv2d.h b/mindspore/lite/src/ops/deconv2d.h index de20b0ab732..0fe42927f00 100644 --- a/mindspore/lite/src/ops/deconv2d.h +++ b/mindspore/lite/src/ops/deconv2d.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_ namespace mindspore { namespace lite { class DeConv2D : public PrimitiveC { public: - explicit DeConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index 8c63bb0ea1e..7fdcf54d513 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -19,77 +19,77 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int DeDepthwiseConv2D::GetFormat() const { return this->primitive->value.AsDeDepthwiseConv2D()->format; } -int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive->value.AsDeDepthwiseConv2D()->channelIn; } +int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value.AsDeDepthwiseConv2D()->format; } +int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value.AsDeDepthwiseConv2D()->channelIn; } int DeDepthwiseConv2D::GetChannelMultiplier() const { - 
return this->primitive->value.AsDeDepthwiseConv2D()->channelMultiplier; + return this->primitive_->value.AsDeDepthwiseConv2D()->channelMultiplier; } -int DeDepthwiseConv2D::GetKernelW() const { return this->primitive->value.AsDeDepthwiseConv2D()->kernelW; } -int DeDepthwiseConv2D::GetKernelH() const { return this->primitive->value.AsDeDepthwiseConv2D()->kernelH; } -int DeDepthwiseConv2D::GetStrideW() const { return this->primitive->value.AsDeDepthwiseConv2D()->strideW; } -int DeDepthwiseConv2D::GetStrideH() const { return this->primitive->value.AsDeDepthwiseConv2D()->strideH; } -int DeDepthwiseConv2D::GetPadMode() const { return this->primitive->value.AsDeDepthwiseConv2D()->padMode; } -int DeDepthwiseConv2D::GetPadUp() const { return this->primitive->value.AsDeDepthwiseConv2D()->padUp; } -int DeDepthwiseConv2D::GetPadDown() const { return this->primitive->value.AsDeDepthwiseConv2D()->padDown; } -int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive->value.AsDeDepthwiseConv2D()->padLeft; } -int DeDepthwiseConv2D::GetPadRight() const { return this->primitive->value.AsDeDepthwiseConv2D()->padRight; } -int DeDepthwiseConv2D::GetDilateW() const { return this->primitive->value.AsDeDepthwiseConv2D()->dilateW; } -int DeDepthwiseConv2D::GetDilateH() const { return this->primitive->value.AsDeDepthwiseConv2D()->dilateH; } -bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive->value.AsDeDepthwiseConv2D()->hasBias; } +int DeDepthwiseConv2D::GetKernelW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->kernelW; } +int DeDepthwiseConv2D::GetKernelH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->kernelH; } +int DeDepthwiseConv2D::GetStrideW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->strideW; } +int DeDepthwiseConv2D::GetStrideH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->strideH; } +int DeDepthwiseConv2D::GetPadMode() const { return 
this->primitive_->value.AsDeDepthwiseConv2D()->padMode; } +int DeDepthwiseConv2D::GetPadUp() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padUp; } +int DeDepthwiseConv2D::GetPadDown() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padDown; } +int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padLeft; } +int DeDepthwiseConv2D::GetPadRight() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padRight; } +int DeDepthwiseConv2D::GetDilateW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->dilateW; } +int DeDepthwiseConv2D::GetDilateH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->dilateH; } +bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive_->value.AsDeDepthwiseConv2D()->hasBias; } int DeDepthwiseConv2D::GetActivationType() const { - return this->primitive->value.AsDeDepthwiseConv2D()->activationType; + return this->primitive_->value.AsDeDepthwiseConv2D()->activationType; } void DeDepthwiseConv2D::SetFormat(int format) { - this->primitive->value.AsDeDepthwiseConv2D()->format = (schema::Format)format; + this->primitive_->value.AsDeDepthwiseConv2D()->format = (schema::Format)format; } void DeDepthwiseConv2D::SetChannelIn(int channel_in) { - this->primitive->value.AsDeDepthwiseConv2D()->channelIn = channel_in; + this->primitive_->value.AsDeDepthwiseConv2D()->channelIn = channel_in; } void DeDepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) { - this->primitive->value.AsDeDepthwiseConv2D()->channelMultiplier = channel_multiplier; + this->primitive_->value.AsDeDepthwiseConv2D()->channelMultiplier = channel_multiplier; } -void DeDepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive->value.AsDeDepthwiseConv2D()->kernelW = kernel_w; } -void DeDepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive->value.AsDeDepthwiseConv2D()->kernelH = kernel_h; } -void DeDepthwiseConv2D::SetStrideW(int stride_w) { 
this->primitive->value.AsDeDepthwiseConv2D()->strideW = stride_w; } -void DeDepthwiseConv2D::SetStrideH(int stride_h) { this->primitive->value.AsDeDepthwiseConv2D()->strideH = stride_h; } +void DeDepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDeDepthwiseConv2D()->kernelW = kernel_w; } +void DeDepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDeDepthwiseConv2D()->kernelH = kernel_h; } +void DeDepthwiseConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDeDepthwiseConv2D()->strideW = stride_w; } +void DeDepthwiseConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDeDepthwiseConv2D()->strideH = stride_h; } void DeDepthwiseConv2D::SetPadMode(int pad_mode) { - this->primitive->value.AsDeDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode; + this->primitive_->value.AsDeDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode; } -void DeDepthwiseConv2D::SetPadUp(int pad_up) { this->primitive->value.AsDeDepthwiseConv2D()->padUp = pad_up; } -void DeDepthwiseConv2D::SetPadDown(int pad_down) { this->primitive->value.AsDeDepthwiseConv2D()->padDown = pad_down; } -void DeDepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive->value.AsDeDepthwiseConv2D()->padLeft = pad_left; } +void DeDepthwiseConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDeDepthwiseConv2D()->padUp = pad_up; } +void DeDepthwiseConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDeDepthwiseConv2D()->padDown = pad_down; } +void DeDepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDeDepthwiseConv2D()->padLeft = pad_left; } void DeDepthwiseConv2D::SetPadRight(int pad_right) { - this->primitive->value.AsDeDepthwiseConv2D()->padRight = pad_right; + this->primitive_->value.AsDeDepthwiseConv2D()->padRight = pad_right; } -void DeDepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive->value.AsDeDepthwiseConv2D()->dilateW = dilate_w; } -void DeDepthwiseConv2D::SetDilateH(int dilate_h) { 
this->primitive->value.AsDeDepthwiseConv2D()->dilateH = dilate_h; } -void DeDepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive->value.AsDeDepthwiseConv2D()->hasBias = has_bias; } +void DeDepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDeDepthwiseConv2D()->dilateW = dilate_w; } +void DeDepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDeDepthwiseConv2D()->dilateH = dilate_h; } +void DeDepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeDepthwiseConv2D()->hasBias = has_bias; } void DeDepthwiseConv2D::SetActivationType(int activation_type) { - this->primitive->value.AsDeDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsDeDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; } #else -int DeDepthwiseConv2D::GetFormat() const { return this->primitive->value_as_DeDepthwiseConv2D()->format(); } -int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive->value_as_DeDepthwiseConv2D()->channelIn(); } +int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DeDepthwiseConv2D()->format(); } +int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DeDepthwiseConv2D()->channelIn(); } int DeDepthwiseConv2D::GetChannelMultiplier() const { - return this->primitive->value_as_DeDepthwiseConv2D()->channelMultiplier(); + return this->primitive_->value_as_DeDepthwiseConv2D()->channelMultiplier(); } -int DeDepthwiseConv2D::GetKernelW() const { return this->primitive->value_as_DeDepthwiseConv2D()->kernelW(); } -int DeDepthwiseConv2D::GetKernelH() const { return this->primitive->value_as_DeDepthwiseConv2D()->kernelH(); } -int DeDepthwiseConv2D::GetStrideW() const { return this->primitive->value_as_DeDepthwiseConv2D()->strideW(); } -int DeDepthwiseConv2D::GetStrideH() const { return this->primitive->value_as_DeDepthwiseConv2D()->strideH(); } -int DeDepthwiseConv2D::GetPadMode() const { 
return this->primitive->value_as_DeDepthwiseConv2D()->padMode(); } -int DeDepthwiseConv2D::GetPadUp() const { return this->primitive->value_as_DeDepthwiseConv2D()->padUp(); } -int DeDepthwiseConv2D::GetPadDown() const { return this->primitive->value_as_DeDepthwiseConv2D()->padDown(); } -int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive->value_as_DeDepthwiseConv2D()->padLeft(); } -int DeDepthwiseConv2D::GetPadRight() const { return this->primitive->value_as_DeDepthwiseConv2D()->padRight(); } -int DeDepthwiseConv2D::GetDilateW() const { return this->primitive->value_as_DeDepthwiseConv2D()->dilateW(); } -int DeDepthwiseConv2D::GetDilateH() const { return this->primitive->value_as_DeDepthwiseConv2D()->dilateH(); } -bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive->value_as_DeDepthwiseConv2D()->hasBias(); } +int DeDepthwiseConv2D::GetKernelW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->kernelW(); } +int DeDepthwiseConv2D::GetKernelH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->kernelH(); } +int DeDepthwiseConv2D::GetStrideW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->strideW(); } +int DeDepthwiseConv2D::GetStrideH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->strideH(); } +int DeDepthwiseConv2D::GetPadMode() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padMode(); } +int DeDepthwiseConv2D::GetPadUp() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padUp(); } +int DeDepthwiseConv2D::GetPadDown() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padDown(); } +int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padLeft(); } +int DeDepthwiseConv2D::GetPadRight() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padRight(); } +int DeDepthwiseConv2D::GetDilateW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->dilateW(); } +int 
DeDepthwiseConv2D::GetDilateH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->dilateH(); } +bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive_->value_as_DeDepthwiseConv2D()->hasBias(); } int DeDepthwiseConv2D::GetActivationType() const { - return this->primitive->value_as_DeDepthwiseConv2D()->activationType(); + return this->primitive_->value_as_DeDepthwiseConv2D()->activationType(); } void DeDepthwiseConv2D::SetFormat(int format) {} @@ -119,7 +119,7 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_, MS_LOG(ERROR) << "output number is invalid"; return 1; } - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto weight = inputs_.at(1); diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.h b/mindspore/lite/src/ops/dedepthwise_conv2d.h index ed317016f11..689d2033bf3 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.h +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_ namespace mindspore { namespace lite { class DeDepthwiseConv2D : public PrimitiveC { public: - explicit DeDepthwiseConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index ec5ccf72c7e..b1b3e2d0261 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthToSpace()->blockSize; } -int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; } +int DepthToSpace::GetBlockSize() const { return this->primitive_->value.AsDepthToSpace()->blockSize; } +int DepthToSpace::GetFormat() const { return this->primitive_->value.AsDepthToSpace()->format; } -void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; } -void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = (schema::Format)format; } +void DepthToSpace::SetBlockSize(int block_size) { this->primitive_->value.AsDepthToSpace()->blockSize = 
block_size; } +void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpace()->format = (schema::Format)format; } #else -int DepthToSpace::GetBlockSize() const { return this->primitive->value_as_DepthToSpace()->blockSize(); } -int DepthToSpace::GetFormat() const { return this->primitive->value_as_DepthToSpace()->format(); } +int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); } +int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); } void DepthToSpace::SetBlockSize(int block_size) {} void DepthToSpace::SetFormat(int format) {} @@ -39,7 +39,7 @@ constexpr int kDepthToSpaceInputNum = 1; } // namespace int DepthToSpace::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/depth_to_space.h b/mindspore/lite/src/ops/depth_to_space.h index 6ab1fc3075d..3ae52550eef 100644 --- a/mindspore/lite/src/ops/depth_to_space.h +++ b/mindspore/lite/src/ops/depth_to_space.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_ namespace mindspore { namespace lite { class DepthToSpace : public PrimitiveC { public: - explicit DepthToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index de6c6acef3f..3903ce83493 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -19,72 +19,74 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int DepthwiseConv2D::GetFormat() const { return this->primitive->value.AsDepthwiseConv2D()->format; } -int DepthwiseConv2D::GetChannelIn() const { return this->primitive->value.AsDepthwiseConv2D()->channelIn; } +int DepthwiseConv2D::GetFormat() const { return this->primitive_->value.AsDepthwiseConv2D()->format; } +int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value.AsDepthwiseConv2D()->channelIn; } int DepthwiseConv2D::GetChannelMultiplier() const { - return this->primitive->value.AsDepthwiseConv2D()->channelMultiplier; + return this->primitive_->value.AsDepthwiseConv2D()->channelMultiplier; } -int DepthwiseConv2D::GetKernelW() const { return this->primitive->value.AsDepthwiseConv2D()->kernelW; } -int DepthwiseConv2D::GetKernelH() const { return 
this->primitive->value.AsDepthwiseConv2D()->kernelH; } -int DepthwiseConv2D::GetStrideW() const { return this->primitive->value.AsDepthwiseConv2D()->strideW; } -int DepthwiseConv2D::GetStrideH() const { return this->primitive->value.AsDepthwiseConv2D()->strideH; } -int DepthwiseConv2D::GetPadMode() const { return this->primitive->value.AsDepthwiseConv2D()->padMode; } -int DepthwiseConv2D::GetPadUp() const { return this->primitive->value.AsDepthwiseConv2D()->padUp; } -int DepthwiseConv2D::GetPadDown() const { return this->primitive->value.AsDepthwiseConv2D()->padDown; } -int DepthwiseConv2D::GetPadLeft() const { return this->primitive->value.AsDepthwiseConv2D()->padLeft; } -int DepthwiseConv2D::GetPadRight() const { return this->primitive->value.AsDepthwiseConv2D()->padRight; } -int DepthwiseConv2D::GetDilateW() const { return this->primitive->value.AsDepthwiseConv2D()->dilateW; } -int DepthwiseConv2D::GetDilateH() const { return this->primitive->value.AsDepthwiseConv2D()->dilateH; } -bool DepthwiseConv2D::GetHasBias() const { return this->primitive->value.AsDepthwiseConv2D()->hasBias; } -int DepthwiseConv2D::GetActivationType() const { return this->primitive->value.AsDepthwiseConv2D()->activationType; } +int DepthwiseConv2D::GetKernelW() const { return this->primitive_->value.AsDepthwiseConv2D()->kernelW; } +int DepthwiseConv2D::GetKernelH() const { return this->primitive_->value.AsDepthwiseConv2D()->kernelH; } +int DepthwiseConv2D::GetStrideW() const { return this->primitive_->value.AsDepthwiseConv2D()->strideW; } +int DepthwiseConv2D::GetStrideH() const { return this->primitive_->value.AsDepthwiseConv2D()->strideH; } +int DepthwiseConv2D::GetPadMode() const { return this->primitive_->value.AsDepthwiseConv2D()->padMode; } +int DepthwiseConv2D::GetPadUp() const { return this->primitive_->value.AsDepthwiseConv2D()->padUp; } +int DepthwiseConv2D::GetPadDown() const { return this->primitive_->value.AsDepthwiseConv2D()->padDown; } +int DepthwiseConv2D::GetPadLeft() 
const { return this->primitive_->value.AsDepthwiseConv2D()->padLeft; } +int DepthwiseConv2D::GetPadRight() const { return this->primitive_->value.AsDepthwiseConv2D()->padRight; } +int DepthwiseConv2D::GetDilateW() const { return this->primitive_->value.AsDepthwiseConv2D()->dilateW; } +int DepthwiseConv2D::GetDilateH() const { return this->primitive_->value.AsDepthwiseConv2D()->dilateH; } +bool DepthwiseConv2D::GetHasBias() const { return this->primitive_->value.AsDepthwiseConv2D()->hasBias; } +int DepthwiseConv2D::GetActivationType() const { return this->primitive_->value.AsDepthwiseConv2D()->activationType; } void DepthwiseConv2D::SetFormat(int format) { - this->primitive->value.AsDepthwiseConv2D()->format = (schema::Format)format; + this->primitive_->value.AsDepthwiseConv2D()->format = (schema::Format)format; } void DepthwiseConv2D::SetChannelIn(int channel_in) { - this->primitive->value.AsDepthwiseConv2D()->channelIn = channel_in; + this->primitive_->value.AsDepthwiseConv2D()->channelIn = channel_in; } void DepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) { - this->primitive->value.AsDepthwiseConv2D()->channelMultiplier = channel_multiplier; + this->primitive_->value.AsDepthwiseConv2D()->channelMultiplier = channel_multiplier; } -void DepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive->value.AsDepthwiseConv2D()->kernelW = kernel_w; } -void DepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive->value.AsDepthwiseConv2D()->kernelH = kernel_h; } -void DepthwiseConv2D::SetStrideW(int stride_w) { this->primitive->value.AsDepthwiseConv2D()->strideW = stride_w; } -void DepthwiseConv2D::SetStrideH(int stride_h) { this->primitive->value.AsDepthwiseConv2D()->strideH = stride_h; } +void DepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDepthwiseConv2D()->kernelW = kernel_w; } +void DepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDepthwiseConv2D()->kernelH = kernel_h; } +void 
DepthwiseConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDepthwiseConv2D()->strideW = stride_w; } +void DepthwiseConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDepthwiseConv2D()->strideH = stride_h; } void DepthwiseConv2D::SetPadMode(int pad_mode) { - this->primitive->value.AsDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode; + this->primitive_->value.AsDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode; } -void DepthwiseConv2D::SetPadUp(int pad_up) { this->primitive->value.AsDepthwiseConv2D()->padUp = pad_up; } -void DepthwiseConv2D::SetPadDown(int pad_down) { this->primitive->value.AsDepthwiseConv2D()->padDown = pad_down; } -void DepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive->value.AsDepthwiseConv2D()->padLeft = pad_left; } -void DepthwiseConv2D::SetPadRight(int pad_right) { this->primitive->value.AsDepthwiseConv2D()->padRight = pad_right; } -void DepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive->value.AsDepthwiseConv2D()->dilateW = dilate_w; } -void DepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive->value.AsDepthwiseConv2D()->dilateH = dilate_h; } -void DepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive->value.AsDepthwiseConv2D()->hasBias = has_bias; } +void DepthwiseConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDepthwiseConv2D()->padUp = pad_up; } +void DepthwiseConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDepthwiseConv2D()->padDown = pad_down; } +void DepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDepthwiseConv2D()->padLeft = pad_left; } +void DepthwiseConv2D::SetPadRight(int pad_right) { this->primitive_->value.AsDepthwiseConv2D()->padRight = pad_right; } +void DepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDepthwiseConv2D()->dilateW = dilate_w; } +void DepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDepthwiseConv2D()->dilateH = dilate_h; } +void DepthwiseConv2D::SetHasBias(bool 
has_bias) { this->primitive_->value.AsDepthwiseConv2D()->hasBias = has_bias; } void DepthwiseConv2D::SetActivationType(int activation_type) { - this->primitive->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; } #else -int DepthwiseConv2D::GetFormat() const { return this->primitive->value_as_DepthwiseConv2D()->format(); } -int DepthwiseConv2D::GetChannelIn() const { return this->primitive->value_as_DepthwiseConv2D()->channelIn(); } +int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); } +int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DepthwiseConv2D()->channelIn(); } int DepthwiseConv2D::GetChannelMultiplier() const { - return this->primitive->value_as_DepthwiseConv2D()->channelMultiplier(); + return this->primitive_->value_as_DepthwiseConv2D()->channelMultiplier(); +} +int DepthwiseConv2D::GetKernelW() const { return this->primitive_->value_as_DepthwiseConv2D()->kernelW(); } +int DepthwiseConv2D::GetKernelH() const { return this->primitive_->value_as_DepthwiseConv2D()->kernelH(); } +int DepthwiseConv2D::GetStrideW() const { return this->primitive_->value_as_DepthwiseConv2D()->strideW(); } +int DepthwiseConv2D::GetStrideH() const { return this->primitive_->value_as_DepthwiseConv2D()->strideH(); } +int DepthwiseConv2D::GetPadMode() const { return this->primitive_->value_as_DepthwiseConv2D()->padMode(); } +int DepthwiseConv2D::GetPadUp() const { return this->primitive_->value_as_DepthwiseConv2D()->padUp(); } +int DepthwiseConv2D::GetPadDown() const { return this->primitive_->value_as_DepthwiseConv2D()->padDown(); } +int DepthwiseConv2D::GetPadLeft() const { return this->primitive_->value_as_DepthwiseConv2D()->padLeft(); } +int DepthwiseConv2D::GetPadRight() const { return this->primitive_->value_as_DepthwiseConv2D()->padRight(); } +int 
DepthwiseConv2D::GetDilateW() const { return this->primitive_->value_as_DepthwiseConv2D()->dilateW(); } +int DepthwiseConv2D::GetDilateH() const { return this->primitive_->value_as_DepthwiseConv2D()->dilateH(); } +bool DepthwiseConv2D::GetHasBias() const { return this->primitive_->value_as_DepthwiseConv2D()->hasBias(); } +int DepthwiseConv2D::GetActivationType() const { + return this->primitive_->value_as_DepthwiseConv2D()->activationType(); } -int DepthwiseConv2D::GetKernelW() const { return this->primitive->value_as_DepthwiseConv2D()->kernelW(); } -int DepthwiseConv2D::GetKernelH() const { return this->primitive->value_as_DepthwiseConv2D()->kernelH(); } -int DepthwiseConv2D::GetStrideW() const { return this->primitive->value_as_DepthwiseConv2D()->strideW(); } -int DepthwiseConv2D::GetStrideH() const { return this->primitive->value_as_DepthwiseConv2D()->strideH(); } -int DepthwiseConv2D::GetPadMode() const { return this->primitive->value_as_DepthwiseConv2D()->padMode(); } -int DepthwiseConv2D::GetPadUp() const { return this->primitive->value_as_DepthwiseConv2D()->padUp(); } -int DepthwiseConv2D::GetPadDown() const { return this->primitive->value_as_DepthwiseConv2D()->padDown(); } -int DepthwiseConv2D::GetPadLeft() const { return this->primitive->value_as_DepthwiseConv2D()->padLeft(); } -int DepthwiseConv2D::GetPadRight() const { return this->primitive->value_as_DepthwiseConv2D()->padRight(); } -int DepthwiseConv2D::GetDilateW() const { return this->primitive->value_as_DepthwiseConv2D()->dilateW(); } -int DepthwiseConv2D::GetDilateH() const { return this->primitive->value_as_DepthwiseConv2D()->dilateH(); } -bool DepthwiseConv2D::GetHasBias() const { return this->primitive->value_as_DepthwiseConv2D()->hasBias(); } -int DepthwiseConv2D::GetActivationType() const { return this->primitive->value_as_DepthwiseConv2D()->activationType(); } void DepthwiseConv2D::SetFormat(int format) {} void DepthwiseConv2D::SetChannelIn(int channel_in) {} @@ -113,7 +115,7 @@ int 
DepthwiseConv2D::InferShape(std::vector inputs_, MS_LOG(ERROR) << "output number is invalid"; return 1; } - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto weight = inputs_.at(1); diff --git a/mindspore/lite/src/ops/depthwise_conv2d.h b/mindspore/lite/src/ops/depthwise_conv2d.h index b61505feee3..eb60575341c 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.h +++ b/mindspore/lite/src/ops/depthwise_conv2d.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_ namespace mindspore { namespace lite { class DepthwiseConv2D : public PrimitiveC { public: - explicit DepthwiseConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/detection_post_process.cc b/mindspore/lite/src/ops/detection_post_process.cc index b1256a7840a..31bb2eb3a3e 100644 --- a/mindspore/lite/src/ops/detection_post_process.cc +++ b/mindspore/lite/src/ops/detection_post_process.cc @@ -19,100 +19,104 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int DetectionPostProcess::GetFormat() const { return this->primitive->value.AsDetectionPostProcess()->format; } -int DetectionPostProcess::GetInputSize() const { return 
this->primitive->value.AsDetectionPostProcess()->inputSize; } -float DetectionPostProcess::GetHScale() const { return this->primitive->value.AsDetectionPostProcess()->hScale; } -float DetectionPostProcess::GetWScale() const { return this->primitive->value.AsDetectionPostProcess()->wScale; } -float DetectionPostProcess::GetXScale() const { return this->primitive->value.AsDetectionPostProcess()->xScale; } -float DetectionPostProcess::GetYScale() const { return this->primitive->value.AsDetectionPostProcess()->yScale; } +int DetectionPostProcess::GetFormat() const { return this->primitive_->value.AsDetectionPostProcess()->format; } +int DetectionPostProcess::GetInputSize() const { return this->primitive_->value.AsDetectionPostProcess()->inputSize; } +float DetectionPostProcess::GetHScale() const { return this->primitive_->value.AsDetectionPostProcess()->hScale; } +float DetectionPostProcess::GetWScale() const { return this->primitive_->value.AsDetectionPostProcess()->wScale; } +float DetectionPostProcess::GetXScale() const { return this->primitive_->value.AsDetectionPostProcess()->xScale; } +float DetectionPostProcess::GetYScale() const { return this->primitive_->value.AsDetectionPostProcess()->yScale; } float DetectionPostProcess::GetNmsIouThreshold() const { - return this->primitive->value.AsDetectionPostProcess()->NmsIouThreshold; + return this->primitive_->value.AsDetectionPostProcess()->NmsIouThreshold; } float DetectionPostProcess::GetNmsScoreThreshold() const { - return this->primitive->value.AsDetectionPostProcess()->NmsScoreThreshold; + return this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold; } long DetectionPostProcess::GetMaxDetections() const { - return this->primitive->value.AsDetectionPostProcess()->MaxDetections; + return this->primitive_->value.AsDetectionPostProcess()->MaxDetections; } long DetectionPostProcess::GetDetectionsPreClass() const { - return this->primitive->value.AsDetectionPostProcess()->DetectionsPreClass; + return 
this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass; } long DetectionPostProcess::GetMaxClassesPreDetection() const { - return this->primitive->value.AsDetectionPostProcess()->MaxClassesPreDetection; + return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection; +} +long DetectionPostProcess::GetNumClasses() const { + return this->primitive_->value.AsDetectionPostProcess()->NumClasses; } -long DetectionPostProcess::GetNumClasses() const { return this->primitive->value.AsDetectionPostProcess()->NumClasses; } bool DetectionPostProcess::GetUseRegularNms() const { - return this->primitive->value.AsDetectionPostProcess()->UseRegularNms; + return this->primitive_->value.AsDetectionPostProcess()->UseRegularNms; } void DetectionPostProcess::SetFormat(int format) { - this->primitive->value.AsDetectionPostProcess()->format = (schema::Format)format; + this->primitive_->value.AsDetectionPostProcess()->format = (schema::Format)format; } void DetectionPostProcess::SetInputSize(int input_size) { - this->primitive->value.AsDetectionPostProcess()->inputSize = input_size; + this->primitive_->value.AsDetectionPostProcess()->inputSize = input_size; } void DetectionPostProcess::SetHScale(float h_scale) { - this->primitive->value.AsDetectionPostProcess()->hScale = h_scale; + this->primitive_->value.AsDetectionPostProcess()->hScale = h_scale; } void DetectionPostProcess::SetWScale(float w_scale) { - this->primitive->value.AsDetectionPostProcess()->wScale = w_scale; + this->primitive_->value.AsDetectionPostProcess()->wScale = w_scale; } void DetectionPostProcess::SetXScale(float x_scale) { - this->primitive->value.AsDetectionPostProcess()->xScale = x_scale; + this->primitive_->value.AsDetectionPostProcess()->xScale = x_scale; } void DetectionPostProcess::SetYScale(float y_scale) { - this->primitive->value.AsDetectionPostProcess()->yScale = y_scale; + this->primitive_->value.AsDetectionPostProcess()->yScale = y_scale; } void 
DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) { - this->primitive->value.AsDetectionPostProcess()->NmsIouThreshold = nms_iou_threshold; + this->primitive_->value.AsDetectionPostProcess()->NmsIouThreshold = nms_iou_threshold; } void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) { - this->primitive->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold; + this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold; } void DetectionPostProcess::SetMaxDetections(long max_detections) { - this->primitive->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections; + this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections; } void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) { - this->primitive->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class; + this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class; } void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) { - this->primitive->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection; + this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection; } void DetectionPostProcess::SetNumClasses(long num_classes) { - this->primitive->value.AsDetectionPostProcess()->NumClasses = num_classes; + this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes; } void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) { - this->primitive->value.AsDetectionPostProcess()->UseRegularNms = use_regular_nms; + this->primitive_->value.AsDetectionPostProcess()->UseRegularNms = use_regular_nms; } #else -int DetectionPostProcess::GetFormat() const { return this->primitive->value_as_DetectionPostProcess()->format(); } -int DetectionPostProcess::GetInputSize() const { return 
this->primitive->value_as_DetectionPostProcess()->inputSize(); } -float DetectionPostProcess::GetHScale() const { return this->primitive->value_as_DetectionPostProcess()->hScale(); } -float DetectionPostProcess::GetWScale() const { return this->primitive->value_as_DetectionPostProcess()->wScale(); } -float DetectionPostProcess::GetXScale() const { return this->primitive->value_as_DetectionPostProcess()->xScale(); } -float DetectionPostProcess::GetYScale() const { return this->primitive->value_as_DetectionPostProcess()->yScale(); } +int DetectionPostProcess::GetFormat() const { return this->primitive_->value_as_DetectionPostProcess()->format(); } +int DetectionPostProcess::GetInputSize() const { + return this->primitive_->value_as_DetectionPostProcess()->inputSize(); +} +float DetectionPostProcess::GetHScale() const { return this->primitive_->value_as_DetectionPostProcess()->hScale(); } +float DetectionPostProcess::GetWScale() const { return this->primitive_->value_as_DetectionPostProcess()->wScale(); } +float DetectionPostProcess::GetXScale() const { return this->primitive_->value_as_DetectionPostProcess()->xScale(); } +float DetectionPostProcess::GetYScale() const { return this->primitive_->value_as_DetectionPostProcess()->yScale(); } float DetectionPostProcess::GetNmsIouThreshold() const { - return this->primitive->value_as_DetectionPostProcess()->NmsIouThreshold(); + return this->primitive_->value_as_DetectionPostProcess()->NmsIouThreshold(); } float DetectionPostProcess::GetNmsScoreThreshold() const { - return this->primitive->value_as_DetectionPostProcess()->NmsScoreThreshold(); + return this->primitive_->value_as_DetectionPostProcess()->NmsScoreThreshold(); } long DetectionPostProcess::GetMaxDetections() const { - return this->primitive->value_as_DetectionPostProcess()->MaxDetections(); + return this->primitive_->value_as_DetectionPostProcess()->MaxDetections(); } long DetectionPostProcess::GetDetectionsPreClass() const { - return 
this->primitive->value_as_DetectionPostProcess()->DetectionsPreClass(); + return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass(); } long DetectionPostProcess::GetMaxClassesPreDetection() const { - return this->primitive->value_as_DetectionPostProcess()->MaxClassesPreDetection(); + return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection(); } long DetectionPostProcess::GetNumClasses() const { - return this->primitive->value_as_DetectionPostProcess()->NumClasses(); + return this->primitive_->value_as_DetectionPostProcess()->NumClasses(); } bool DetectionPostProcess::GetUseRegularNms() const { - return this->primitive->value_as_DetectionPostProcess()->UseRegularNms(); + return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms(); } void DetectionPostProcess::SetFormat(int format) {} diff --git a/mindspore/lite/src/ops/detection_post_process.h b/mindspore/lite/src/ops/detection_post_process.h index 2aa87f656b7..ab8624f0bf5 100644 --- a/mindspore/lite/src/ops/detection_post_process.h +++ b/mindspore/lite/src/ops/detection_post_process.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_ namespace mindspore { namespace lite { class DetectionPostProcess : public PrimitiveC { public: - explicit DetectionPostProcess(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetFormat() const; int GetInputSize() const; diff --git a/mindspore/lite/src/ops/div.cc b/mindspore/lite/src/ops/div.cc index cd042ab5091..93da12cc7e2 100644 --- a/mindspore/lite/src/ops/div.cc +++ b/mindspore/lite/src/ops/div.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Div::GetActivationType() const { return this->primitive->value.AsDiv()->activationType; } +int Div::GetActivationType() const { return this->primitive_->value.AsDiv()->activationType; } void Div::SetActivationType(int activation_type) { - this->primitive->value.AsDiv()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsDiv()->activationType = (schema::ActivationType)activation_type; } #else -int Div::GetActivationType() const { return this->primitive->value_as_Div()->activationType(); } +int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); } void Div::SetActivationType(int activation_type) {} #endif diff --git a/mindspore/lite/src/ops/div.h b/mindspore/lite/src/ops/div.h index b6e0ca344f7..043275cfcde 100644 --- 
a/mindspore/lite/src/ops/div.h +++ b/mindspore/lite/src/ops/div.h @@ -14,26 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DIV_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DIV_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DIV_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DIV_H_ - namespace mindspore { namespace lite { class Div : public Arithmetic { public: - explicit Div(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {} int GetActivationType() const; void SetActivationType(int activation_type); diff --git a/mindspore/lite/src/ops/dropout.cc b/mindspore/lite/src/ops/dropout.cc index 83835811c75..b381e8f031f 100644 --- a/mindspore/lite/src/ops/dropout.cc +++ b/mindspore/lite/src/ops/dropout.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float Dropout::GetRatio() const { return this->primitive->value.AsDropout()->ratio; } +float Dropout::GetRatio() const { return this->primitive_->value.AsDropout()->ratio; } -void Dropout::SetRatio(float ratio) { this->primitive->value.AsDropout()->ratio = ratio; } +void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio = ratio; } #else -float Dropout::GetRatio() const { return this->primitive->value_as_Dropout()->ratio(); } +float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); } void Dropout::SetRatio(float ratio) {} #endif diff --git a/mindspore/lite/src/ops/dropout.h b/mindspore/lite/src/ops/dropout.h index a59302e89b5..3804ad16757 100644 --- a/mindspore/lite/src/ops/dropout.h +++ b/mindspore/lite/src/ops/dropout.h @@ 
-14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_ namespace mindspore { namespace lite { class Dropout : public PrimitiveC { public: - explicit Dropout(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetRatio() const; void SetRatio(float ratio); diff --git a/mindspore/lite/src/ops/eltwise.cc b/mindspore/lite/src/ops/eltwise.cc index 95eafc7cfb0..3c5e5365120 100644 --- a/mindspore/lite/src/ops/eltwise.cc +++ b/mindspore/lite/src/ops/eltwise.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Eltwise::GetMode() const { return this->primitive->value.AsEltwise()->mode; } +int Eltwise::GetMode() const { return this->primitive_->value.AsEltwise()->mode; } -void Eltwise::SetMode(int mode) { this->primitive->value.AsEltwise()->mode = (schema::EltwiseMode)mode; } +void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (schema::EltwiseMode)mode; } #else -int Eltwise::GetMode() const { return this->primitive->value_as_Eltwise()->mode(); } +int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); } void Eltwise::SetMode(int mode) {} #endif diff --git a/mindspore/lite/src/ops/eltwise.h b/mindspore/lite/src/ops/eltwise.h index 9ea6ac81f9a..0d24bb550d8 100644 --- a/mindspore/lite/src/ops/eltwise.h +++ b/mindspore/lite/src/ops/eltwise.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_ namespace mindspore { namespace lite { class Eltwise : public PrimitiveC { public: - explicit Eltwise(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetMode() const; void SetMode(int mode); diff --git a/mindspore/lite/src/ops/elu.cc b/mindspore/lite/src/ops/elu.cc index 2c2ad09d69e..6c164e45ad7 100644 --- a/mindspore/lite/src/ops/elu.cc +++ b/mindspore/lite/src/ops/elu.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float Elu::GetAlpha() const { return this->primitive->value.AsElu()->alpha; } +float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; } -void Elu::SetAlpha(float alpha) { this->primitive->value.AsElu()->alpha = alpha; } +void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; } #else -float Elu::GetAlpha() const { return this->primitive->value_as_Elu()->alpha(); } +float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); } void Elu::SetAlpha(float alpha) {} #endif diff --git a/mindspore/lite/src/ops/elu.h b/mindspore/lite/src/ops/elu.h index e85acee30d5..814ff4b4e77 100644 --- a/mindspore/lite/src/ops/elu.h +++ b/mindspore/lite/src/ops/elu.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ELU_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ELU_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ELU_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ELU_H_ namespace mindspore { namespace lite { class Elu : public PrimitiveC { public: - explicit Elu(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetAlpha() const; void SetAlpha(float alpha); diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc index fa8e8fa0949..4a4b391529f 100644 --- a/mindspore/lite/src/ops/embedding_lookup.cc +++ b/mindspore/lite/src/ops/embedding_lookup.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value.AsEmbeddingLookup()->maxNorm; } +float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value.AsEmbeddingLookup()->maxNorm; } -void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive->value.AsEmbeddingLookup()->maxNorm = max_norm; } +void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmbeddingLookup()->maxNorm = max_norm; } #else -float EmbeddingLookup::GetMaxNorm() const { return this->primitive->value_as_EmbeddingLookup()->maxNorm(); } +float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); } void EmbeddingLookup::SetMaxNorm(float max_norm) {} #endif int EmbeddingLookup::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if 
(inputs_.size() < kDoubleNum) { MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs"; return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/embedding_lookup.h b/mindspore/lite/src/ops/embedding_lookup.h index 76105f3f841..a0f42137942 100644 --- a/mindspore/lite/src/ops/embedding_lookup.h +++ b/mindspore/lite/src/ops/embedding_lookup.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_ -#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_ namespace mindspore { namespace lite { class EmbeddingLookup : public PrimitiveC { public: - explicit EmbeddingLookup(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; float GetMaxNorm() const; diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.cc b/mindspore/lite/src/ops/embedding_lookup_sparse.cc index c443af5353c..eb37231f9b3 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.cc +++ b/mindspore/lite/src/ops/embedding_lookup_sparse.cc @@ -20,35 +20,35 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE std::vector EmbeddingLookupSparse::GetSpIds() const { - return this->primitive->value.AsEmbeddingLookupSparse()->spIds; + return this->primitive_->value.AsEmbeddingLookupSparse()->spIds; } std::vector EmbeddingLookupSparse::GetSpWeights() const { - return this->primitive->value.AsEmbeddingLookupSparse()->spWeights; + 
return this->primitive_->value.AsEmbeddingLookupSparse()->spWeights; } -float EmbeddingLookupSparse::GetMaxNortm() const { return this->primitive->value.AsEmbeddingLookupSparse()->maxNortm; } +float EmbeddingLookupSparse::GetMaxNortm() const { return this->primitive_->value.AsEmbeddingLookupSparse()->maxNortm; } void EmbeddingLookupSparse::SetSpIds(const std::vector &sp_ids) { - this->primitive->value.AsEmbeddingLookupSparse()->spIds = sp_ids; + this->primitive_->value.AsEmbeddingLookupSparse()->spIds = sp_ids; } void EmbeddingLookupSparse::SetSpWeights(const std::vector &sp_weights) { - this->primitive->value.AsEmbeddingLookupSparse()->spWeights = sp_weights; + this->primitive_->value.AsEmbeddingLookupSparse()->spWeights = sp_weights; } void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) { - this->primitive->value.AsEmbeddingLookupSparse()->maxNortm = max_nortm; + this->primitive_->value.AsEmbeddingLookupSparse()->maxNortm = max_nortm; } #else std::vector EmbeddingLookupSparse::GetSpIds() const { - auto fb_vector = this->primitive->value_as_EmbeddingLookupSparse()->spIds(); + auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spIds(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector EmbeddingLookupSparse::GetSpWeights() const { - auto fb_vector = this->primitive->value_as_EmbeddingLookupSparse()->spWeights(); + auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spWeights(); return std::vector(fb_vector->begin(), fb_vector->end()); } float EmbeddingLookupSparse::GetMaxNortm() const { - return this->primitive->value_as_EmbeddingLookupSparse()->maxNortm(); + return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm(); } void EmbeddingLookupSparse::SetSpIds(const std::vector &sp_ids) {} diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.h b/mindspore/lite/src/ops/embedding_lookup_sparse.h index 836ae232688..d35f05a805c 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.h +++ 
b/mindspore/lite/src/ops/embedding_lookup_sparse.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_ namespace mindspore { namespace lite { class EmbeddingLookupSparse : public PrimitiveC { public: - explicit EmbeddingLookupSparse(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetSpIds() const; std::vector GetSpWeights() const; diff --git a/mindspore/lite/src/ops/equal.h b/mindspore/lite/src/ops/equal.h index 1acf5d87139..60fa27b9792 100644 --- a/mindspore/lite/src/ops/equal.h +++ b/mindspore/lite/src/ops/equal.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_ namespace mindspore { namespace lite { class Equal : public Arithmetic { public: - explicit Equal(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/exp.h b/mindspore/lite/src/ops/exp.h index 4b6c10a75fc..fbe17b848e9 100644 --- a/mindspore/lite/src/ops/exp.h +++ b/mindspore/lite/src/ops/exp.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_EXP_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EXP_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_EXP_H_ -#define LITE_MINDSPORE_LITE_C_OPS_EXP_H_ namespace mindspore { namespace lite { class Exp : public ArithmeticSelf { public: - explicit Exp(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 36cc4d8064a..38b584d6812 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ExpandDims::GetDim() const { return this->primitive->value.AsExpandDims()->dim; } +int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()->dim; } -void ExpandDims::SetDim(int dim) { this->primitive->value.AsExpandDims()->dim = dim; } +void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } #else -int ExpandDims::GetDim() const { return this->primitive->value_as_ExpandDims()->dim(); } +int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); } void ExpandDims::SetDim(int dim) {} #endif int ExpandDims::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/expand_dims.h 
b/mindspore/lite/src/ops/expand_dims.h index 3a3aa1f6b8d..45a040cc3af 100644 --- a/mindspore/lite/src/ops/expand_dims.h +++ b/mindspore/lite/src/ops/expand_dims.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_ +#define LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_ namespace mindspore { namespace lite { class ExpandDims : public PrimitiveC { public: - explicit ExpandDims(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDim() const; diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc index a6bf9e293b0..f027b42753f 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc @@ -20,24 +20,24 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE bool FakeQuantWithMinMaxVars::GetNarrowRange() const { - return this->primitive->value.AsFakeQuantWithMinMaxVars()->narrowRange; + return this->primitive_->value.AsFakeQuantWithMinMaxVars()->narrowRange; } -int FakeQuantWithMinMaxVars::GetNumBits() const { return this->primitive->value.AsFakeQuantWithMinMaxVars()->numBits; } +int FakeQuantWithMinMaxVars::GetNumBits() const { return this->primitive_->value.AsFakeQuantWithMinMaxVars()->numBits; } void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) { - 
this->primitive->value.AsFakeQuantWithMinMaxVars()->narrowRange = narrow_range; + this->primitive_->value.AsFakeQuantWithMinMaxVars()->narrowRange = narrow_range; } void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) { - this->primitive->value.AsFakeQuantWithMinMaxVars()->numBits = num_bits; + this->primitive_->value.AsFakeQuantWithMinMaxVars()->numBits = num_bits; } #else bool FakeQuantWithMinMaxVars::GetNarrowRange() const { - return this->primitive->value_as_FakeQuantWithMinMaxVars()->narrowRange(); + return this->primitive_->value_as_FakeQuantWithMinMaxVars()->narrowRange(); } int FakeQuantWithMinMaxVars::GetNumBits() const { - return this->primitive->value_as_FakeQuantWithMinMaxVars()->numBits(); + return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits(); } void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {} diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h index a8cea93ccfd..e1d3babb3b8 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_ namespace mindspore { namespace lite { class FakeQuantWithMinMaxVars : public PrimitiveC { public: - explicit FakeQuantWithMinMaxVars(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {} bool GetNarrowRange() const; int GetNumBits() const; diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 96fddefa3bd..35682a0430a 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Fill::GetDims() const { return this->primitive->value.AsFill()->dims; } +std::vector Fill::GetDims() const { return this->primitive_->value.AsFill()->dims; } -void Fill::SetDims(const std::vector &dims) { this->primitive->value.AsFill()->dims = dims; } +void Fill::SetDims(const std::vector &dims) { this->primitive_->value.AsFill()->dims = dims; } #else std::vector Fill::GetDims() const { - auto fb_vector = this->primitive->value_as_Fill()->dims(); + auto fb_vector = this->primitive_->value_as_Fill()->dims(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -34,7 +34,7 @@ void Fill::SetDims(const std::vector &dims) {} #endif int Fill::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + 
MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); auto output = outputs_.front(); if (input == nullptr || output == nullptr) { diff --git a/mindspore/lite/src/ops/fill.h b/mindspore/lite/src/ops/fill.h index 7850d09b8e3..e7766722162 100644 --- a/mindspore/lite/src/ops/fill.h +++ b/mindspore/lite/src/ops/fill.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FILL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FILL_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FILL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FILL_H_ namespace mindspore { namespace lite { class Fill : public PrimitiveC { public: - explicit Fill(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDims() const; diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index b9259b8c4f1..8298824334c 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace lite { int Flatten::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); auto output = outputs_.front(); if (input == nullptr || output == nullptr) { diff --git a/mindspore/lite/src/ops/flatten.h b/mindspore/lite/src/ops/flatten.h index 0c0023809c5..1ab44c577f0 100644 --- a/mindspore/lite/src/ops/flatten.h +++ b/mindspore/lite/src/ops/flatten.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_ namespace mindspore { namespace lite { class Flatten : public PrimitiveC { public: - explicit Flatten(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/floor.h b/mindspore/lite/src/ops/floor.h index 7f61378c77d..dc13935a980 100644 --- a/mindspore/lite/src/ops/floor.h +++ b/mindspore/lite/src/ops/floor.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_ +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { class Floor : public ArithmeticSelf { public: - explicit Floor(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_div.h b/mindspore/lite/src/ops/floor_div.h index 4b187426b4e..2920e0f666c 100644 --- a/mindspore/lite/src/ops/floor_div.h +++ b/mindspore/lite/src/ops/floor_div.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_ namespace mindspore { namespace lite { class FloorDiv : public Arithmetic { public: - explicit FloorDiv(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_mod.h b/mindspore/lite/src/ops/floor_mod.h index d9e78e10b7f..294b1b34ffb 100644 --- a/mindspore/lite/src/ops/floor_mod.h +++ b/mindspore/lite/src/ops/floor_mod.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_ namespace mindspore { namespace lite { class FloorMod : public Arithmetic { public: - explicit FloorMod(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc index 70dc056da79..35d4c505fac 100644 --- a/mindspore/lite/src/ops/full_connection.cc +++ b/mindspore/lite/src/ops/full_connection.cc @@ -19,23 +19,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -bool FullConnection::GetHasBias() const { return this->primitive->value.AsFullConnection()->hasBias; } -int FullConnection::GetAxis() const { return this->primitive->value.AsFullConnection()->axis; } -bool FullConnection::GetUseAxis() const { return this->primitive->value.AsFullConnection()->useAxis; } -int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType; } +bool FullConnection::GetHasBias() const { return this->primitive_->value.AsFullConnection()->hasBias; } +int FullConnection::GetAxis() const { return this->primitive_->value.AsFullConnection()->axis; } +bool FullConnection::GetUseAxis() const { return this->primitive_->value.AsFullConnection()->useAxis; } +int FullConnection::GetActivationType() const { return this->primitive_->value.AsFullConnection()->activationType; } -void 
FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullConnection()->hasBias = has_bias; } -void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; } -void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; } +void FullConnection::SetHasBias(bool has_bias) { this->primitive_->value.AsFullConnection()->hasBias = has_bias; } +void FullConnection::SetAxis(int axis) { this->primitive_->value.AsFullConnection()->axis = axis; } +void FullConnection::SetUseAxis(bool use_axis) { this->primitive_->value.AsFullConnection()->useAxis = use_axis; } void FullConnection::SetActivationType(int activationType) { - this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType; + this->primitive_->value.AsFullConnection()->activationType = (schema::ActivationType)activationType; } #else -bool FullConnection::GetHasBias() const { return this->primitive->value_as_FullConnection()->hasBias(); } -int FullConnection::GetAxis() const { return this->primitive->value_as_FullConnection()->axis(); } -bool FullConnection::GetUseAxis() const { return this->primitive->value_as_FullConnection()->useAxis(); } -int FullConnection::GetActivationType() const { return this->primitive->value_as_FullConnection()->activationType(); } +bool FullConnection::GetHasBias() const { return this->primitive_->value_as_FullConnection()->hasBias(); } +int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConnection()->axis(); } +bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); } +int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); } void FullConnection::SetHasBias(bool has_bias) {} void FullConnection::SetAxis(int axis) {} @@ -44,7 +44,7 @@ void FullConnection::SetActivationType(int activationType) {} #endif int 
FullConnection::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input0 = inputs_.front(); MS_ASSERT(input0 != nullptr); auto input1 = inputs_[1]; diff --git a/mindspore/lite/src/ops/full_connection.h b/mindspore/lite/src/ops/full_connection.h index 5c971bac787..4ed808bbaf9 100644 --- a/mindspore/lite/src/ops/full_connection.h +++ b/mindspore/lite/src/ops/full_connection.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_ namespace mindspore { namespace lite { class FullConnection : public PrimitiveC { public: - explicit FullConnection(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetHasBias() const; diff --git a/mindspore/lite/src/ops/fused_batchnorm.cc b/mindspore/lite/src/ops/fused_batchnorm.cc index 5e97cef7d18..ecb7e0c1fec 100644 --- a/mindspore/lite/src/ops/fused_batchnorm.cc +++ b/mindspore/lite/src/ops/fused_batchnorm.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float FusedBatchNorm::GetEpsilon() const { return this->primitive->value.AsFusedBatchNorm()->epsilon; } -float FusedBatchNorm::GetMomentum() const { return this->primitive->value.AsFusedBatchNorm()->momentum; } -int FusedBatchNorm::GetSpatial() const { 
return this->primitive->value.AsFusedBatchNorm()->spatial; } +float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value.AsFusedBatchNorm()->epsilon; } +float FusedBatchNorm::GetMomentum() const { return this->primitive_->value.AsFusedBatchNorm()->momentum; } +int FusedBatchNorm::GetSpatial() const { return this->primitive_->value.AsFusedBatchNorm()->spatial; } -void FusedBatchNorm::SetEpsilon(float epsilon) { this->primitive->value.AsFusedBatchNorm()->epsilon = epsilon; } -void FusedBatchNorm::SetMomentum(float momentum) { this->primitive->value.AsFusedBatchNorm()->momentum = momentum; } -void FusedBatchNorm::SetSpatial(int spatial) { this->primitive->value.AsFusedBatchNorm()->spatial = spatial; } +void FusedBatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsFusedBatchNorm()->epsilon = epsilon; } +void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFusedBatchNorm()->momentum = momentum; } +void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; } #else -float FusedBatchNorm::GetEpsilon() const { return this->primitive->value_as_FusedBatchNorm()->epsilon(); } -float FusedBatchNorm::GetMomentum() const { return this->primitive->value_as_FusedBatchNorm()->momentum(); } -int FusedBatchNorm::GetSpatial() const { return this->primitive->value_as_FusedBatchNorm()->spatial(); } +float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_FusedBatchNorm()->epsilon(); } +float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); } +int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } void FusedBatchNorm::SetEpsilon(float epsilon) {} void FusedBatchNorm::SetMomentum(float momentum) {} diff --git a/mindspore/lite/src/ops/fused_batchnorm.h b/mindspore/lite/src/ops/fused_batchnorm.h index 17381fdaa55..dd96a7e533d 100644 --- 
a/mindspore/lite/src/ops/fused_batchnorm.h +++ b/mindspore/lite/src/ops/fused_batchnorm.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_ namespace mindspore { namespace lite { class FusedBatchNorm : public PrimitiveC { public: - explicit FusedBatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetEpsilon() const; float GetMomentum() const; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index 4e8bac1fa88..d4546da0580 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -22,23 +22,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Gather::GetAxis() const { return this->primitive->value.AsGather()->axis; } -int Gather::GetBatchDims() const { return this->primitive->value.AsGather()->batchDims; } +int Gather::GetAxis() const { return this->primitive_->value.AsGather()->axis; } +int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->batchDims; } -void Gather::SetAxis(int axis) { this->primitive->value.AsGather()->axis = axis; } -void Gather::SetBatchDims(int batch_dims) { this->primitive->value.AsGather()->batchDims = batch_dims; } +void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; } +void Gather::SetBatchDims(int batch_dims) { 
this->primitive_->value.AsGather()->batchDims = batch_dims; } #else -int Gather::GetAxis() const { return this->primitive->value_as_Gather()->axis(); } -int Gather::GetBatchDims() const { return this->primitive->value_as_Gather()->batchDims(); } +int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); } +int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); } void Gather::SetAxis(int axis) {} void Gather::SetBatchDims(int batch_dims) {} #endif int Gather::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { MS_LOG(ERROR) << "Gather should have two inputs"; return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/gather.h b/mindspore/lite/src/ops/gather.h index a53b0de3194..dac98f8ca06 100644 --- a/mindspore/lite/src/ops/gather.h +++ b/mindspore/lite/src/ops/gather.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_GATHER_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_H_ -#define LITE_MINDSPORE_LITE_C_OPS_GATHER_H_ namespace mindspore { namespace lite { class Gather : public PrimitiveC { public: - explicit Gather(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index a6696bfefe8..e5da4346cba 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int GatherNd::GetBatchDims() const { return this->primitive->value.AsGatherNd()->batchDims; } +int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()->batchDims; } -void GatherNd::SetBatchDims(int batch_dims) { this->primitive->value.AsGatherNd()->batchDims = batch_dims; } +void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; } #else -int GatherNd::GetBatchDims() const { return this->primitive->value_as_GatherNd()->batchDims(); } +int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } void GatherNd::SetBatchDims(int batch_dims) {} #endif int GatherNd::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { MS_LOG(ERROR) 
<< "GatherNd should have two inputs"; return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/gather_nd.h b/mindspore/lite/src/ops/gather_nd.h index 1016e687175..43d8f963db6 100644 --- a/mindspore/lite/src/ops/gather_nd.h +++ b/mindspore/lite/src/ops/gather_nd.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_ND_H_ +#define LITE_MINDSPORE_LITE_C_OPS_GATHER_ND_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_ND_H_ -#define LITE_MINDSPORE_LITE_C_OPS_GATHER_ND_H_ namespace mindspore { namespace lite { class GatherNd : public PrimitiveC { public: - explicit GatherNd(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBatchDims() const; diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index 059eb53bbf4..f959deb966c 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -13,26 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_GREATER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_GREATER_H_ #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_GREATER_H_ -#define LITE_MINDSPORE_LITE_C_OPS_GREATER_H_ namespace mindspore { namespace lite { class Greater : public Arithmetic { public: - explicit Greater(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.h b/mindspore/lite/src/ops/greater_equal.h index 1cbacd9bd06..72d74689f1f 100644 --- a/mindspore/lite/src/ops/greater_equal.h +++ b/mindspore/lite/src/ops/greater_equal.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_GREATER_EQUAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_GREATER_EQUAL_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_GREATER_EQUAL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_GREATER_EQUAL_H_ namespace mindspore { namespace lite { class GreaterEqual : public Arithmetic { public: - explicit GreaterEqual(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/l2_norm.cc b/mindspore/lite/src/ops/l2_norm.cc index c9cbe584dd0..15bbb6713a2 100644 --- a/mindspore/lite/src/ops/l2_norm.cc +++ b/mindspore/lite/src/ops/l2_norm.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector L2Norm::GetAxis() const { return this->primitive->value.AsL2Norm()->axis; } -float L2Norm::GetEpsilon() const { return this->primitive->value.AsL2Norm()->epsilon; } +std::vector L2Norm::GetAxis() const { return this->primitive_->value.AsL2Norm()->axis; } +float L2Norm::GetEpsilon() const { return this->primitive_->value.AsL2Norm()->epsilon; } -void L2Norm::SetAxis(const std::vector &axis) { this->primitive->value.AsL2Norm()->axis = axis; } -void L2Norm::SetEpsilon(float epsilon) { this->primitive->value.AsL2Norm()->epsilon = epsilon; } +void L2Norm::SetAxis(const std::vector &axis) { this->primitive_->value.AsL2Norm()->axis = axis; } +void L2Norm::SetEpsilon(float epsilon) { this->primitive_->value.AsL2Norm()->epsilon = epsilon; } #else std::vector L2Norm::GetAxis() const { - auto fb_vector = this->primitive->value_as_L2Norm()->axis(); + auto 
fb_vector = this->primitive_->value_as_L2Norm()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } -float L2Norm::GetEpsilon() const { return this->primitive->value_as_L2Norm()->epsilon(); } +float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); } void L2Norm::SetAxis(const std::vector &axis) {} void L2Norm::SetEpsilon(float epsilon) {} diff --git a/mindspore/lite/src/ops/l2_norm.h b/mindspore/lite/src/ops/l2_norm.h index 912795b5bab..fe4a242fbe3 100644 --- a/mindspore/lite/src/ops/l2_norm.h +++ b/mindspore/lite/src/ops/l2_norm.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_L2_NORM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_L2_NORM_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_L2_NORM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_L2_NORM_H_ namespace mindspore { namespace lite { class L2Norm : public PrimitiveC { public: - explicit L2Norm(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetAxis() const; float GetEpsilon() const; diff --git a/mindspore/lite/src/ops/leaky_relu.cc b/mindspore/lite/src/ops/leaky_relu.cc index e0850f50a9c..f7d2a4bf111 100644 --- a/mindspore/lite/src/ops/leaky_relu.cc +++ b/mindspore/lite/src/ops/leaky_relu.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float LeakyReLU::GetNegativeSlope() const { return this->primitive->value.AsLeakyReLU()->negativeSlope; } +float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value.AsLeakyReLU()->negativeSlope; } void LeakyReLU::SetNegativeSlope(float 
negative_slope) { - this->primitive->value.AsLeakyReLU()->negativeSlope = negative_slope; + this->primitive_->value.AsLeakyReLU()->negativeSlope = negative_slope; } #else -float LeakyReLU::GetNegativeSlope() const { return this->primitive->value_as_LeakyReLU()->negativeSlope(); } +float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value_as_LeakyReLU()->negativeSlope(); } void LeakyReLU::SetNegativeSlope(float negative_slope) {} #endif diff --git a/mindspore/lite/src/ops/leaky_relu.h b/mindspore/lite/src/ops/leaky_relu.h index 4cdaa740906..02469ec5b89 100644 --- a/mindspore/lite/src/ops/leaky_relu.h +++ b/mindspore/lite/src/ops/leaky_relu.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LEAKY_RE_L_U_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LEAKY_RE_L_U_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LEAKY_RE_L_U_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LEAKY_RE_L_U_H_ namespace mindspore { namespace lite { class LeakyReLU : public PrimitiveC { public: - explicit LeakyReLU(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetNegativeSlope() const; void SetNegativeSlope(float negative_slope); diff --git a/mindspore/lite/src/ops/less.h b/mindspore/lite/src/ops/less.h index 58c359735be..a5ccda4e071 100644 --- a/mindspore/lite/src/ops/less.h +++ b/mindspore/lite/src/ops/less.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LESS_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LESS_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LESS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LESS_H_ namespace mindspore { namespace lite { class Less : public Arithmetic { public: - explicit Less(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.h b/mindspore/lite/src/ops/less_equal.h index cd1308dcd60..52cdf561689 100644 --- a/mindspore/lite/src/ops/less_equal.h +++ b/mindspore/lite/src/ops/less_equal.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LESS_EQUAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LESS_EQUAL_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LESS_EQUAL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LESS_EQUAL_H_ namespace mindspore { namespace lite { class LessEqual : public Arithmetic { public: - explicit LessEqual(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/local_response_normalization.cc b/mindspore/lite/src/ops/local_response_normalization.cc index 2e6dc944c68..891f165e6e7 100644 --- a/mindspore/lite/src/ops/local_response_normalization.cc +++ b/mindspore/lite/src/ops/local_response_normalization.cc @@ -20,44 +20,44 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE int LocalResponseNormalization::GetDepthRadius() const { - return this->primitive->value.AsLocalResponseNormalization()->depth_radius; + return this->primitive_->value.AsLocalResponseNormalization()->depth_radius; } float LocalResponseNormalization::GetBias() const { - return this->primitive->value.AsLocalResponseNormalization()->bias; + return this->primitive_->value.AsLocalResponseNormalization()->bias; } float LocalResponseNormalization::GetAlpha() const { - return this->primitive->value.AsLocalResponseNormalization()->alpha; + return this->primitive_->value.AsLocalResponseNormalization()->alpha; } float LocalResponseNormalization::GetBeta() const { - return this->primitive->value.AsLocalResponseNormalization()->beta; + return this->primitive_->value.AsLocalResponseNormalization()->beta; } void 
LocalResponseNormalization::SetDepthRadius(int depth_radius) { - this->primitive->value.AsLocalResponseNormalization()->depth_radius = depth_radius; + this->primitive_->value.AsLocalResponseNormalization()->depth_radius = depth_radius; } void LocalResponseNormalization::SetBias(float bias) { - this->primitive->value.AsLocalResponseNormalization()->bias = bias; + this->primitive_->value.AsLocalResponseNormalization()->bias = bias; } void LocalResponseNormalization::SetAlpha(float alpha) { - this->primitive->value.AsLocalResponseNormalization()->alpha = alpha; + this->primitive_->value.AsLocalResponseNormalization()->alpha = alpha; } void LocalResponseNormalization::SetBeta(float beta) { - this->primitive->value.AsLocalResponseNormalization()->beta = beta; + this->primitive_->value.AsLocalResponseNormalization()->beta = beta; } #else int LocalResponseNormalization::GetDepthRadius() const { - return this->primitive->value_as_LocalResponseNormalization()->depth_radius(); + return this->primitive_->value_as_LocalResponseNormalization()->depth_radius(); } float LocalResponseNormalization::GetBias() const { - return this->primitive->value_as_LocalResponseNormalization()->bias(); + return this->primitive_->value_as_LocalResponseNormalization()->bias(); } float LocalResponseNormalization::GetAlpha() const { - return this->primitive->value_as_LocalResponseNormalization()->alpha(); + return this->primitive_->value_as_LocalResponseNormalization()->alpha(); } float LocalResponseNormalization::GetBeta() const { - return this->primitive->value_as_LocalResponseNormalization()->beta(); + return this->primitive_->value_as_LocalResponseNormalization()->beta(); } void LocalResponseNormalization::SetDepthRadius(int depth_radius) {} diff --git a/mindspore/lite/src/ops/local_response_normalization.h b/mindspore/lite/src/ops/local_response_normalization.h index f099967195a..7b19e08961b 100644 --- a/mindspore/lite/src/ops/local_response_normalization.h +++ 
b/mindspore/lite/src/ops/local_response_normalization.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOCAL_RESPONSE_NORMALIZATION_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOCAL_RESPONSE_NORMALIZATION_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LOCAL_RESPONSE_NORMALIZATION_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LOCAL_RESPONSE_NORMALIZATION_H_ namespace mindspore { namespace lite { class LocalResponseNormalization : public PrimitiveC { public: - explicit LocalResponseNormalization(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LocalResponseNormalization(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit LocalResponseNormalization(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetDepthRadius() const; float GetBias() const; diff --git a/mindspore/lite/src/ops/log.h b/mindspore/lite/src/ops/log.h index d8a05dd2682..88af6a906a7 100644 --- a/mindspore/lite/src/ops/log.h +++ b/mindspore/lite/src/ops/log.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOG_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOG_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LOG_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LOG_H_ namespace mindspore { namespace lite { class Log : public ArithmeticSelf { public: - explicit Log(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Log(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_and.h b/mindspore/lite/src/ops/logical_and.h index 4dea085e5d9..5cbacaf63b3 100644 --- a/mindspore/lite/src/ops/logical_and.h +++ b/mindspore/lite/src/ops/logical_and.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_AND_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_AND_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_AND_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_AND_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class LogicalAnd : public Arithmetic { public: - explicit LogicalAnd(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LogicalAnd(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_not.h b/mindspore/lite/src/ops/logical_not.h index 675d4906053..c1cf519a977 100644 --- a/mindspore/lite/src/ops/logical_not.h +++ b/mindspore/lite/src/ops/logical_not.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_NOT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_NOT_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_NOT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_NOT_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class LogicalNot : public ArithmeticSelf { public: - explicit LogicalNot(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LogicalNot(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_or.h b/mindspore/lite/src/ops/logical_or.h index 158cb367c02..327a164d549 100644 --- a/mindspore/lite/src/ops/logical_or.h +++ b/mindspore/lite/src/ops/logical_or.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_OR_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_OR_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LOGICAL_OR_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LOGICAL_OR_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class LogicalOr : public Arithmetic { public: - explicit LogicalOr(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit LogicalOr(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/lrn.cc b/mindspore/lite/src/ops/lrn.cc index eb9a8fcd231..859b8728ed8 100644 --- a/mindspore/lite/src/ops/lrn.cc +++ b/mindspore/lite/src/ops/lrn.cc @@ -19,22 +19,22 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float Lrn::GetAlpha() const { return this->primitive->value.AsLrn()->alpha; } -float Lrn::GetBeta() const { return this->primitive->value.AsLrn()->beta; } -float Lrn::GetBias() const { return this->primitive->value.AsLrn()->bias; } -int Lrn::GetSize() const { return this->primitive->value.AsLrn()->size; } +float Lrn::GetAlpha() const { return this->primitive_->value.AsLrn()->alpha; } +float Lrn::GetBeta() const { return this->primitive_->value.AsLrn()->beta; } +float Lrn::GetBias() const { return this->primitive_->value.AsLrn()->bias; } +int Lrn::GetSize() const { return this->primitive_->value.AsLrn()->size; } -void Lrn::SetAlpha(float alpha) { this->primitive->value.AsLrn()->alpha = alpha; } -void Lrn::SetBeta(float beta) { this->primitive->value.AsLrn()->beta = beta; } -void Lrn::SetBias(float bias) { this->primitive->value.AsLrn()->bias = bias; } -void 
Lrn::SetSize(int size) { this->primitive->value.AsLrn()->size = size; } +void Lrn::SetAlpha(float alpha) { this->primitive_->value.AsLrn()->alpha = alpha; } +void Lrn::SetBeta(float beta) { this->primitive_->value.AsLrn()->beta = beta; } +void Lrn::SetBias(float bias) { this->primitive_->value.AsLrn()->bias = bias; } +void Lrn::SetSize(int size) { this->primitive_->value.AsLrn()->size = size; } #else -float Lrn::GetAlpha() const { return this->primitive->value_as_Lrn()->alpha(); } -float Lrn::GetBeta() const { return this->primitive->value_as_Lrn()->beta(); } -float Lrn::GetBias() const { return this->primitive->value_as_Lrn()->bias(); } -int Lrn::GetSize() const { return this->primitive->value_as_Lrn()->size(); } +float Lrn::GetAlpha() const { return this->primitive_->value_as_Lrn()->alpha(); } +float Lrn::GetBeta() const { return this->primitive_->value_as_Lrn()->beta(); } +float Lrn::GetBias() const { return this->primitive_->value_as_Lrn()->bias(); } +int Lrn::GetSize() const { return this->primitive_->value_as_Lrn()->size(); } void Lrn::SetAlpha(float alpha) {} void Lrn::SetBeta(float beta) {} diff --git a/mindspore/lite/src/ops/lrn.h b/mindspore/lite/src/ops/lrn.h index cc91453882e..8f718114285 100644 --- a/mindspore/lite/src/ops/lrn.h +++ b/mindspore/lite/src/ops/lrn.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LRN_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LRN_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LRN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LRN_H_ namespace mindspore { namespace lite { class Lrn : public PrimitiveC { public: - explicit Lrn(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Lrn(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Lrn(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetAlpha() const; float GetBeta() const; diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index 8c6b27f7ba2..7e997c2fefd 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -bool Lstm::GetBidirection() const { return this->primitive->value.AsLstm()->bidirection; } +bool Lstm::GetBidirection() const { return this->primitive_->value.AsLstm()->bidirection; } -void Lstm::SetBidirection(bool bidirection) { this->primitive->value.AsLstm()->bidirection = bidirection; } +void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()->bidirection = bidirection; } #else -bool Lstm::GetBidirection() const { return this->primitive->value_as_Lstm()->bidirection(); } +bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); } void Lstm::SetBidirection(bool bidirection) {} #endif @@ -33,7 +33,7 @@ void Lstm::SetBidirection(bool bidirection) {} const int kLstmInputNum = 6; const int kLstmOutputNum = 3; int Lstm::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != 
kLstmInputNum || outputs_.size() != kLstmOutputNum) { MS_LOG(ERROR) << "OpLstm inputs or outputs size error."; return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/lstm.h b/mindspore/lite/src/ops/lstm.h index 6ae0020d462..47adedff167 100644 --- a/mindspore/lite/src/ops/lstm.h +++ b/mindspore/lite/src/ops/lstm.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_LSTM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_LSTM_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_LSTM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_LSTM_H_ namespace mindspore { namespace lite { class Lstm : public PrimitiveC { public: - explicit Lstm(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Lstm(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetBidirection() const; diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 51cc51d3273..a68de23fb71 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -20,23 +20,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -bool MatMul::GetTransposeA() const { return this->primitive->value.AsMatMul()->transposeA; } -bool MatMul::GetTransposeB() const { return this->primitive->value.AsMatMul()->transposeB; } +bool MatMul::GetTransposeA() const { return this->primitive_->value.AsMatMul()->transposeA; } +bool MatMul::GetTransposeB() const { return this->primitive_->value.AsMatMul()->transposeB; } -void MatMul::SetTransposeA(bool transpose_a) { this->primitive->value.AsMatMul()->transposeA = transpose_a; } -void 
MatMul::SetTransposeB(bool transpose_b) { this->primitive->value.AsMatMul()->transposeB = transpose_b; } +void MatMul::SetTransposeA(bool transpose_a) { this->primitive_->value.AsMatMul()->transposeA = transpose_a; } +void MatMul::SetTransposeB(bool transpose_b) { this->primitive_->value.AsMatMul()->transposeB = transpose_b; } #else -bool MatMul::GetTransposeA() const { return this->primitive->value_as_MatMul()->transposeA(); } -bool MatMul::GetTransposeB() const { return this->primitive->value_as_MatMul()->transposeB(); } +bool MatMul::GetTransposeA() const { return this->primitive_->value_as_MatMul()->transposeA(); } +bool MatMul::GetTransposeB() const { return this->primitive_->value_as_MatMul()->transposeB(); } void MatMul::SetTransposeA(bool transpose_a) {} void MatMul::SetTransposeB(bool transpose_b) {} #endif int MatMul::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input0 = inputs_.front(); MS_ASSERT(input0 != nullptr); auto input1 = inputs_.at(1); diff --git a/mindspore/lite/src/ops/matmul.h b/mindspore/lite/src/ops/matmul.h index d079cf587f5..c25f4d9b938 100644 --- a/mindspore/lite/src/ops/matmul.h +++ b/mindspore/lite/src/ops/matmul.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MAT_MUL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MAT_MUL_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MAT_MUL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MAT_MUL_H_ namespace mindspore { namespace lite { class MatMul : public PrimitiveC { public: - explicit MatMul(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit MatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetTransposeA() const; diff --git a/mindspore/lite/src/ops/matrix_diag.cc b/mindspore/lite/src/ops/matrix_diag.cc index 0f329e11790..dbc21ce8359 100644 --- a/mindspore/lite/src/ops/matrix_diag.cc +++ b/mindspore/lite/src/ops/matrix_diag.cc @@ -19,24 +19,24 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int MatrixDiag::GetK() const { return this->primitive->value.AsMatrixDiag()->k; } -int MatrixDiag::GetNumRows() const { return this->primitive->value.AsMatrixDiag()->numRows; } -int MatrixDiag::GetNumCols() const { return this->primitive->value.AsMatrixDiag()->numCols; } -float MatrixDiag::GetPaddingValue() const { return this->primitive->value.AsMatrixDiag()->paddingValue; } +int MatrixDiag::GetK() const { return this->primitive_->value.AsMatrixDiag()->k; } +int MatrixDiag::GetNumRows() const { return this->primitive_->value.AsMatrixDiag()->numRows; } +int MatrixDiag::GetNumCols() const { return this->primitive_->value.AsMatrixDiag()->numCols; } +float MatrixDiag::GetPaddingValue() const { return this->primitive_->value.AsMatrixDiag()->paddingValue; } -void MatrixDiag::SetK(int k) { this->primitive->value.AsMatrixDiag()->k = k; } -void 
MatrixDiag::SetNumRows(int num_rows) { this->primitive->value.AsMatrixDiag()->numRows = num_rows; } -void MatrixDiag::SetNumCols(int num_cols) { this->primitive->value.AsMatrixDiag()->numCols = num_cols; } +void MatrixDiag::SetK(int k) { this->primitive_->value.AsMatrixDiag()->k = k; } +void MatrixDiag::SetNumRows(int num_rows) { this->primitive_->value.AsMatrixDiag()->numRows = num_rows; } +void MatrixDiag::SetNumCols(int num_cols) { this->primitive_->value.AsMatrixDiag()->numCols = num_cols; } void MatrixDiag::SetPaddingValue(float padding_value) { - this->primitive->value.AsMatrixDiag()->paddingValue = padding_value; + this->primitive_->value.AsMatrixDiag()->paddingValue = padding_value; } #else -int MatrixDiag::GetK() const { return this->primitive->value_as_MatrixDiag()->k(); } -int MatrixDiag::GetNumRows() const { return this->primitive->value_as_MatrixDiag()->numRows(); } -int MatrixDiag::GetNumCols() const { return this->primitive->value_as_MatrixDiag()->numCols(); } -float MatrixDiag::GetPaddingValue() const { return this->primitive->value_as_MatrixDiag()->paddingValue(); } +int MatrixDiag::GetK() const { return this->primitive_->value_as_MatrixDiag()->k(); } +int MatrixDiag::GetNumRows() const { return this->primitive_->value_as_MatrixDiag()->numRows(); } +int MatrixDiag::GetNumCols() const { return this->primitive_->value_as_MatrixDiag()->numCols(); } +float MatrixDiag::GetPaddingValue() const { return this->primitive_->value_as_MatrixDiag()->paddingValue(); } void MatrixDiag::SetK(int k) {} void MatrixDiag::SetNumRows(int num_rows) {} diff --git a/mindspore/lite/src/ops/matrix_diag.h b/mindspore/lite/src/ops/matrix_diag.h index a619c8dcfec..14c41d60a8e 100644 --- a/mindspore/lite/src/ops/matrix_diag.h +++ b/mindspore/lite/src/ops/matrix_diag.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ namespace mindspore { namespace lite { class MatrixDiag : public PrimitiveC { public: - explicit MatrixDiag(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit MatrixDiag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit MatrixDiag(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetK() const; int GetNumRows() const; diff --git a/mindspore/lite/src/ops/maximum.h b/mindspore/lite/src/ops/maximum.h index 97e8c938c37..ba391b71acf 100644 --- a/mindspore/lite/src/ops/maximum.h +++ b/mindspore/lite/src/ops/maximum.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MAXIMUM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MAXIMUM_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MAXIMUM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MAXIMUM_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Maximum : public Arithmetic { public: - explicit Maximum(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc index e3986fdebbc..d6cfb5e204f 100644 --- a/mindspore/lite/src/ops/mean.cc +++ b/mindspore/lite/src/ops/mean.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Mean::GetAxis() const { return this->primitive->value.AsMean()->axis; } -bool Mean::GetKeepDims() const { return this->primitive->value.AsMean()->keepDims; } +std::vector Mean::GetAxis() const { return this->primitive_->value.AsMean()->axis; } +bool Mean::GetKeepDims() const { return this->primitive_->value.AsMean()->keepDims; } -void Mean::SetAxis(const std::vector &axis) { this->primitive->value.AsMean()->axis = axis; } -void Mean::SetKeepDims(bool keep_dims) { this->primitive->value.AsMean()->keepDims = keep_dims; } +void Mean::SetAxis(const std::vector &axis) { this->primitive_->value.AsMean()->axis = axis; } +void Mean::SetKeepDims(bool keep_dims) { this->primitive_->value.AsMean()->keepDims = keep_dims; } #else std::vector Mean::GetAxis() const { - auto fb_vector = this->primitive->value_as_Mean()->axis(); + auto fb_vector = 
this->primitive_->value_as_Mean()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } -bool Mean::GetKeepDims() const { return this->primitive->value_as_Mean()->keepDims(); } +bool Mean::GetKeepDims() const { return this->primitive_->value_as_Mean()->keepDims(); } void Mean::SetAxis(const std::vector &axis) {} void Mean::SetKeepDims(bool keep_dims) {} @@ -55,7 +55,7 @@ int Mean::InferShape(std::vector inputs_, std::vectorprimitive == nullptr) { + if (this->primitive_ == nullptr) { return RET_NULL_PTR; } diff --git a/mindspore/lite/src/ops/mean.h b/mindspore/lite/src/ops/mean.h index 6c5927f5219..873cc75055b 100644 --- a/mindspore/lite/src/ops/mean.h +++ b/mindspore/lite/src/ops/mean.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MEAN_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MEAN_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MEAN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MEAN_H_ namespace mindspore { namespace lite { class Mean : public PrimitiveC { public: - explicit Mean(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Mean(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Mean(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/minimum.h b/mindspore/lite/src/ops/minimum.h index 36bb5ff77a6..1b11fdaca35 100644 --- a/mindspore/lite/src/ops/minimum.h +++ b/mindspore/lite/src/ops/minimum.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MINIMUM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MINIMUM_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MINIMUM_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MINIMUM_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Minimum : public Arithmetic { public: - explicit Minimum(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Minimum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mul.cc b/mindspore/lite/src/ops/mul.cc index c02baa7c0c4..fea06d35388 100644 --- a/mindspore/lite/src/ops/mul.cc +++ b/mindspore/lite/src/ops/mul.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Mul::GetActivationType() const { return this->primitive->value.AsMul()->activationType; } +int Mul::GetActivationType() const { return this->primitive_->value.AsMul()->activationType; } void Mul::SetActivationType(int activation_type) { - this->primitive->value.AsMul()->activationType = (schema::ActivationType) activation_type; + this->primitive_->value.AsMul()->activationType = (schema::ActivationType)activation_type; } #else -int Mul::GetActivationType() const { return this->primitive->value_as_Mul()->activationType(); } +int Mul::GetActivationType() const { return this->primitive_->value_as_Mul()->activationType(); } void Mul::SetActivationType(int activation_type) {} #endif diff --git a/mindspore/lite/src/ops/mul.h b/mindspore/lite/src/ops/mul.h index 2a8ede87415..97ca5fccabc 100644 --- a/mindspore/lite/src/ops/mul.h +++ b/mindspore/lite/src/ops/mul.h @@ -14,26 +14,23 @@ * 
limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_MUL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MUL_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_MUL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_MUL_H_ namespace mindspore { namespace lite { class Mul : public Arithmetic { public: - explicit Mul(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Mul(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {} int GetActivationType() const; void SetActivationType(int activation_type); diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc index a18558d23f3..170f04020f1 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.cc +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace lite { int Nchw2Nhwc::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/nchw2nhwc.h b/mindspore/lite/src/ops/nchw2nhwc.h index edd85e23642..3ea40ecc67f 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.h +++ b/mindspore/lite/src/ops/nchw2nhwc.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_NCHW_2_NHWC_H_ +#define LITE_MINDSPORE_LITE_C_OPS_NCHW_2_NHWC_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_NCHW_2_NHWC_H_ -#define LITE_MINDSPORE_LITE_C_OPS_NCHW_2_NHWC_H_ namespace mindspore { namespace lite { class Nchw2Nhwc : public PrimitiveC { public: - explicit Nchw2Nhwc(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Nchw2Nhwc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Nchw2Nhwc(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc index 39b9b3854fd..9bff02ab2e4 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.cc +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace lite { int Nhwc2Nchw::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/nhwc2nchw.h b/mindspore/lite/src/ops/nhwc2nchw.h index 54cab1716c9..c7f05f845b1 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.h +++ b/mindspore/lite/src/ops/nhwc2nchw.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_NHWC_2_NCHW_H_ +#define LITE_MINDSPORE_LITE_C_OPS_NHWC_2_NCHW_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_NHWC_2_NCHW_H_ -#define LITE_MINDSPORE_LITE_C_OPS_NHWC_2_NCHW_H_ namespace mindspore { namespace lite { class Nhwc2Nchw : public PrimitiveC { public: - explicit Nhwc2Nchw(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Nhwc2Nchw(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Nhwc2Nchw(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/not_equal.h b/mindspore/lite/src/ops/not_equal.h index 18e10176665..f7c9a77bacd 100644 --- a/mindspore/lite/src/ops/not_equal.h +++ b/mindspore/lite/src/ops/not_equal.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_NOT_EQUAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_NOT_EQUAL_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_NOT_EQUAL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_NOT_EQUAL_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class NotEqual : public Arithmetic { public: - explicit NotEqual(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 7398361adea..da1ef942788 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int OneHot::GetAxis() const { return this->primitive->value.AsOneHot()->axis; } +int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; } -void OneHot::SetAxis(int axis) { this->primitive->value.AsOneHot()->axis = axis; } +void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; } #else -int OneHot::GetAxis() const { return this->primitive->value_as_OneHot()->axis(); } +int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } void OneHot::SetAxis(int axis) {} #endif @@ -34,7 +34,7 @@ namespace { constexpr size_t kOneHotInputNum = 4; } int OneHot::InferShape(std::vector inputs, std::vector outputs) { - if (this->primitive == nullptr) { + if (this->primitive_ == nullptr) { return RET_NULL_PTR; } diff --git a/mindspore/lite/src/ops/one_hot.h b/mindspore/lite/src/ops/one_hot.h index 
7193178d4ae..e9703c0d606 100644 --- a/mindspore/lite/src/ops/one_hot.h +++ b/mindspore/lite/src/ops/one_hot.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ONE_HOT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ONE_HOT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ONE_HOT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ONE_HOT_H_ namespace mindspore { namespace lite { class OneHot : public PrimitiveC { public: - explicit OneHot(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit OneHot(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 45bae3ace2d..eb116d789d0 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -19,24 +19,24 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Pad::GetPaddings() const { return this->primitive->value.AsPad()->paddings; } -int Pad::GetPaddingMode() const { return this->primitive->value.AsPad()->paddingMode; } -float Pad::GetConstantValue() const { return this->primitive->value.AsPad()->constantValue; } +std::vector Pad::GetPaddings() const { return this->primitive_->value.AsPad()->paddings; } +int Pad::GetPaddingMode() const { return this->primitive_->value.AsPad()->paddingMode; } +float Pad::GetConstantValue() const { return this->primitive_->value.AsPad()->constantValue; } -void Pad::SetPaddings(const std::vector &paddings) { this->primitive->value.AsPad()->paddings = paddings; } +void Pad::SetPaddings(const std::vector &paddings) { 
this->primitive_->value.AsPad()->paddings = paddings; } void Pad::SetPaddingMode(int padding_mode) { - this->primitive->value.AsPad()->paddingMode = (schema::PaddingMode) padding_mode; + this->primitive_->value.AsPad()->paddingMode = (schema::PaddingMode)padding_mode; } -void Pad::SetConstantValue(float constant_value) { this->primitive->value.AsPad()->constantValue = constant_value; } +void Pad::SetConstantValue(float constant_value) { this->primitive_->value.AsPad()->constantValue = constant_value; } #else std::vector Pad::GetPaddings() const { - auto fb_vector = this->primitive->value_as_Pad()->paddings(); + auto fb_vector = this->primitive_->value_as_Pad()->paddings(); return std::vector(fb_vector->begin(), fb_vector->end()); } -int Pad::GetPaddingMode() const { return this->primitive->value_as_Pad()->paddingMode(); } -float Pad::GetConstantValue() const { return this->primitive->value_as_Pad()->constantValue(); } +int Pad::GetPaddingMode() const { return this->primitive_->value_as_Pad()->paddingMode(); } +float Pad::GetConstantValue() const { return this->primitive_->value_as_Pad()->constantValue(); } void Pad::SetPaddings(const std::vector &paddings) {} void Pad::SetPaddingMode(int padding_mode) {} @@ -46,8 +46,8 @@ namespace { const size_t kInputRank = 4; } // namespace int Pad::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); - if (this->primitive == nullptr) { + MS_ASSERT(this->primitive_ != nullptr); + if (this->primitive_ == nullptr) { return RET_NULL_PTR; } diff --git a/mindspore/lite/src/ops/pad.h b/mindspore/lite/src/ops/pad.h index 94359d6fd91..bedeabf2263 100644 --- a/mindspore/lite/src/ops/pad.h +++ b/mindspore/lite/src/ops/pad.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_PAD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_PAD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_PAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_PAD_H_ namespace mindspore { namespace lite { class Pad : public PrimitiveC { public: - explicit Pad(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Pad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Pad(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPaddings() const; diff --git a/mindspore/lite/src/ops/permute.cc b/mindspore/lite/src/ops/permute.cc index ea2045243bc..235d7a03f81 100644 --- a/mindspore/lite/src/ops/permute.cc +++ b/mindspore/lite/src/ops/permute.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Permute::GetOrder() const { return this->primitive->value.AsPermute()->order; } +std::vector Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; } -void Permute::SetOrder(const std::vector &order) { this->primitive->value.AsPermute()->order = order; } +void Permute::SetOrder(const std::vector &order) { this->primitive_->value.AsPermute()->order = order; } #else std::vector Permute::GetOrder() const { - auto fb_vector = this->primitive->value_as_Permute()->order(); + auto fb_vector = this->primitive_->value_as_Permute()->order(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/permute.h b/mindspore/lite/src/ops/permute.h index 53c0bafa589..2b5ee61efd9 100644 --- a/mindspore/lite/src/ops/permute.h +++ b/mindspore/lite/src/ops/permute.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ namespace mindspore { namespace lite { class Permute : public PrimitiveC { public: - explicit Permute(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetOrder() const; void SetOrder(const std::vector &order); diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index ac08607d675..f24ec9600fe 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -20,53 +20,53 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Pooling::GetFormat() const { return this->primitive->value.AsPooling()->format; } -int Pooling::GetPoolingMode() const { return this->primitive->value.AsPooling()->poolingMode; } -bool Pooling::GetGlobal() const { return this->primitive->value.AsPooling()->global; } -int Pooling::GetWindowW() const { return this->primitive->value.AsPooling()->windowW; } -int Pooling::GetWindowH() const { return this->primitive->value.AsPooling()->windowH; } -int Pooling::GetStrideW() const { return this->primitive->value.AsPooling()->strideW; } -int Pooling::GetStrideH() const { return this->primitive->value.AsPooling()->strideH; } -int Pooling::GetPadMode() const { return this->primitive->value.AsPooling()->padMode; } -int Pooling::GetPadUp() const { return this->primitive->value.AsPooling()->padUp; } -int Pooling::GetPadDown() const { return this->primitive->value.AsPooling()->padDown; } 
-int Pooling::GetPadLeft() const { return this->primitive->value.AsPooling()->padLeft; } -int Pooling::GetPadRight() const { return this->primitive->value.AsPooling()->padRight; } -int Pooling::GetRoundMode() const { return this->primitive->value.AsPooling()->roundMode; } +int Pooling::GetFormat() const { return this->primitive_->value.AsPooling()->format; } +int Pooling::GetPoolingMode() const { return this->primitive_->value.AsPooling()->poolingMode; } +bool Pooling::GetGlobal() const { return this->primitive_->value.AsPooling()->global; } +int Pooling::GetWindowW() const { return this->primitive_->value.AsPooling()->windowW; } +int Pooling::GetWindowH() const { return this->primitive_->value.AsPooling()->windowH; } +int Pooling::GetStrideW() const { return this->primitive_->value.AsPooling()->strideW; } +int Pooling::GetStrideH() const { return this->primitive_->value.AsPooling()->strideH; } +int Pooling::GetPadMode() const { return this->primitive_->value.AsPooling()->padMode; } +int Pooling::GetPadUp() const { return this->primitive_->value.AsPooling()->padUp; } +int Pooling::GetPadDown() const { return this->primitive_->value.AsPooling()->padDown; } +int Pooling::GetPadLeft() const { return this->primitive_->value.AsPooling()->padLeft; } +int Pooling::GetPadRight() const { return this->primitive_->value.AsPooling()->padRight; } +int Pooling::GetRoundMode() const { return this->primitive_->value.AsPooling()->roundMode; } -void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format) format; } +void Pooling::SetFormat(int format) { this->primitive_->value.AsPooling()->format = (schema::Format)format; } void Pooling::SetPoolingMode(int pooling_mode) { - this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode) pooling_mode; + this->primitive_->value.AsPooling()->poolingMode = (schema::PoolMode)pooling_mode; } -void Pooling::SetGlobal(bool global) { this->primitive->value.AsPooling()->global = global; } -void 
Pooling::SetWindowW(int window_w) { this->primitive->value.AsPooling()->windowW = window_w; } -void Pooling::SetWindowH(int window_h) { this->primitive->value.AsPooling()->windowH = window_h; } -void Pooling::SetStrideW(int stride_w) { this->primitive->value.AsPooling()->strideW = stride_w; } -void Pooling::SetStrideH(int stride_h) { this->primitive->value.AsPooling()->strideH = stride_h; } -void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode) pad_mode; } -void Pooling::SetPadUp(int pad_up) { this->primitive->value.AsPooling()->padUp = pad_up; } -void Pooling::SetPadDown(int pad_down) { this->primitive->value.AsPooling()->padDown = pad_down; } -void Pooling::SetPadLeft(int pad_left) { this->primitive->value.AsPooling()->padLeft = pad_left; } -void Pooling::SetPadRight(int pad_right) { this->primitive->value.AsPooling()->padRight = pad_right; } +void Pooling::SetGlobal(bool global) { this->primitive_->value.AsPooling()->global = global; } +void Pooling::SetWindowW(int window_w) { this->primitive_->value.AsPooling()->windowW = window_w; } +void Pooling::SetWindowH(int window_h) { this->primitive_->value.AsPooling()->windowH = window_h; } +void Pooling::SetStrideW(int stride_w) { this->primitive_->value.AsPooling()->strideW = stride_w; } +void Pooling::SetStrideH(int stride_h) { this->primitive_->value.AsPooling()->strideH = stride_h; } +void Pooling::SetPadMode(int pad_mode) { this->primitive_->value.AsPooling()->padMode = (schema::PadMode)pad_mode; } +void Pooling::SetPadUp(int pad_up) { this->primitive_->value.AsPooling()->padUp = pad_up; } +void Pooling::SetPadDown(int pad_down) { this->primitive_->value.AsPooling()->padDown = pad_down; } +void Pooling::SetPadLeft(int pad_left) { this->primitive_->value.AsPooling()->padLeft = pad_left; } +void Pooling::SetPadRight(int pad_right) { this->primitive_->value.AsPooling()->padRight = pad_right; } void Pooling::SetRoundMode(int round_mode) { - 
this->primitive->value.AsPooling()->roundMode = (schema::RoundMode) round_mode; + this->primitive_->value.AsPooling()->roundMode = (schema::RoundMode)round_mode; } #else -int Pooling::GetFormat() const { return this->primitive->value_as_Pooling()->format(); } -int Pooling::GetPoolingMode() const { return this->primitive->value_as_Pooling()->poolingMode(); } -bool Pooling::GetGlobal() const { return this->primitive->value_as_Pooling()->global(); } -int Pooling::GetWindowW() const { return this->primitive->value_as_Pooling()->windowW(); } -int Pooling::GetWindowH() const { return this->primitive->value_as_Pooling()->windowH(); } -int Pooling::GetStrideW() const { return this->primitive->value_as_Pooling()->strideW(); } -int Pooling::GetStrideH() const { return this->primitive->value_as_Pooling()->strideH(); } -int Pooling::GetPadMode() const { return this->primitive->value_as_Pooling()->padMode(); } -int Pooling::GetPadUp() const { return this->primitive->value_as_Pooling()->padUp(); } -int Pooling::GetPadDown() const { return this->primitive->value_as_Pooling()->padDown(); } -int Pooling::GetPadLeft() const { return this->primitive->value_as_Pooling()->padLeft(); } -int Pooling::GetPadRight() const { return this->primitive->value_as_Pooling()->padRight(); } -int Pooling::GetRoundMode() const { return this->primitive->value_as_Pooling()->roundMode(); } +int Pooling::GetFormat() const { return this->primitive_->value_as_Pooling()->format(); } +int Pooling::GetPoolingMode() const { return this->primitive_->value_as_Pooling()->poolingMode(); } +bool Pooling::GetGlobal() const { return this->primitive_->value_as_Pooling()->global(); } +int Pooling::GetWindowW() const { return this->primitive_->value_as_Pooling()->windowW(); } +int Pooling::GetWindowH() const { return this->primitive_->value_as_Pooling()->windowH(); } +int Pooling::GetStrideW() const { return this->primitive_->value_as_Pooling()->strideW(); } +int Pooling::GetStrideH() const { return 
this->primitive_->value_as_Pooling()->strideH(); } +int Pooling::GetPadMode() const { return this->primitive_->value_as_Pooling()->padMode(); } +int Pooling::GetPadUp() const { return this->primitive_->value_as_Pooling()->padUp(); } +int Pooling::GetPadDown() const { return this->primitive_->value_as_Pooling()->padDown(); } +int Pooling::GetPadLeft() const { return this->primitive_->value_as_Pooling()->padLeft(); } +int Pooling::GetPadRight() const { return this->primitive_->value_as_Pooling()->padRight(); } +int Pooling::GetRoundMode() const { return this->primitive_->value_as_Pooling()->roundMode(); } void Pooling::SetFormat(int format) {} void Pooling::SetPoolingMode(int pooling_mode) {} @@ -90,7 +90,7 @@ int Pooling::PadLeft() const { return this->pad_l_; } int Pooling::PadRight() const { return this->pad_r_; } int Pooling::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); @@ -126,7 +126,7 @@ int Pooling::InferShape(std::vector inputs_, std::vector(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1; output_w = std::floor(static_cast(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1; diff --git a/mindspore/lite/src/ops/pooling.h b/mindspore/lite/src/ops/pooling.h index 84e42590a60..77bf8f261ac 100644 --- a/mindspore/lite/src/ops/pooling.h +++ b/mindspore/lite/src/ops/pooling.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_H_ +#define LITE_MINDSPORE_LITE_C_OPS_POOLING_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POOLING_H_ namespace mindspore { namespace lite { class Pooling : public PrimitiveC { public: - explicit Pooling(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Pooling(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Pooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index 345c5d04daf..f4e28d9f9a5 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -19,55 +19,55 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int PoolingGrad::GetFormat() const { return this->primitive->value.AsPoolingGrad()->format; } -int PoolingGrad::GetPoolingMode() const { return this->primitive->value.AsPoolingGrad()->poolingMode; } -bool PoolingGrad::GetGlobal() const { return this->primitive->value.AsPoolingGrad()->global; } -int PoolingGrad::GetWindowW() const { return this->primitive->value.AsPoolingGrad()->windowW; } -int PoolingGrad::GetWindowH() const { return this->primitive->value.AsPoolingGrad()->windowH; } -int PoolingGrad::GetStrideW() const { return this->primitive->value.AsPoolingGrad()->strideW; } -int PoolingGrad::GetStrideH() const { return this->primitive->value.AsPoolingGrad()->strideH; } -int PoolingGrad::GetPadMode() const { return this->primitive->value.AsPoolingGrad()->padMode; } -int PoolingGrad::GetPadUp() const { return 
this->primitive->value.AsPoolingGrad()->padUp; } -int PoolingGrad::GetPadDown() const { return this->primitive->value.AsPoolingGrad()->padDown; } -int PoolingGrad::GetPadLeft() const { return this->primitive->value.AsPoolingGrad()->padLeft; } -int PoolingGrad::GetPadRight() const { return this->primitive->value.AsPoolingGrad()->padRight; } -int PoolingGrad::GetRoundMode() const { return this->primitive->value.AsPoolingGrad()->roundMode; } +int PoolingGrad::GetFormat() const { return this->primitive_->value.AsPoolingGrad()->format; } +int PoolingGrad::GetPoolingMode() const { return this->primitive_->value.AsPoolingGrad()->poolingMode; } +bool PoolingGrad::GetGlobal() const { return this->primitive_->value.AsPoolingGrad()->global; } +int PoolingGrad::GetWindowW() const { return this->primitive_->value.AsPoolingGrad()->windowW; } +int PoolingGrad::GetWindowH() const { return this->primitive_->value.AsPoolingGrad()->windowH; } +int PoolingGrad::GetStrideW() const { return this->primitive_->value.AsPoolingGrad()->strideW; } +int PoolingGrad::GetStrideH() const { return this->primitive_->value.AsPoolingGrad()->strideH; } +int PoolingGrad::GetPadMode() const { return this->primitive_->value.AsPoolingGrad()->padMode; } +int PoolingGrad::GetPadUp() const { return this->primitive_->value.AsPoolingGrad()->padUp; } +int PoolingGrad::GetPadDown() const { return this->primitive_->value.AsPoolingGrad()->padDown; } +int PoolingGrad::GetPadLeft() const { return this->primitive_->value.AsPoolingGrad()->padLeft; } +int PoolingGrad::GetPadRight() const { return this->primitive_->value.AsPoolingGrad()->padRight; } +int PoolingGrad::GetRoundMode() const { return this->primitive_->value.AsPoolingGrad()->roundMode; } -void PoolingGrad::SetFormat(int format) { this->primitive->value.AsPoolingGrad()->format = (schema::Format)format; } +void PoolingGrad::SetFormat(int format) { this->primitive_->value.AsPoolingGrad()->format = (schema::Format)format; } void PoolingGrad::SetPoolingMode(int 
pooling_mode) { - this->primitive->value.AsPoolingGrad()->poolingMode = (schema::PoolMode)pooling_mode; + this->primitive_->value.AsPoolingGrad()->poolingMode = (schema::PoolMode)pooling_mode; } -void PoolingGrad::SetGlobal(bool global) { this->primitive->value.AsPoolingGrad()->global = global; } -void PoolingGrad::SetWindowW(int window_w) { this->primitive->value.AsPoolingGrad()->windowW = window_w; } -void PoolingGrad::SetWindowH(int window_h) { this->primitive->value.AsPoolingGrad()->windowH = window_h; } -void PoolingGrad::SetStrideW(int stride_w) { this->primitive->value.AsPoolingGrad()->strideW = stride_w; } -void PoolingGrad::SetStrideH(int stride_h) { this->primitive->value.AsPoolingGrad()->strideH = stride_h; } +void PoolingGrad::SetGlobal(bool global) { this->primitive_->value.AsPoolingGrad()->global = global; } +void PoolingGrad::SetWindowW(int window_w) { this->primitive_->value.AsPoolingGrad()->windowW = window_w; } +void PoolingGrad::SetWindowH(int window_h) { this->primitive_->value.AsPoolingGrad()->windowH = window_h; } +void PoolingGrad::SetStrideW(int stride_w) { this->primitive_->value.AsPoolingGrad()->strideW = stride_w; } +void PoolingGrad::SetStrideH(int stride_h) { this->primitive_->value.AsPoolingGrad()->strideH = stride_h; } void PoolingGrad::SetPadMode(int pad_mode) { - this->primitive->value.AsPoolingGrad()->padMode = (schema::PadMode)pad_mode; + this->primitive_->value.AsPoolingGrad()->padMode = (schema::PadMode)pad_mode; } -void PoolingGrad::SetPadUp(int pad_up) { this->primitive->value.AsPoolingGrad()->padUp = pad_up; } -void PoolingGrad::SetPadDown(int pad_down) { this->primitive->value.AsPoolingGrad()->padDown = pad_down; } -void PoolingGrad::SetPadLeft(int pad_left) { this->primitive->value.AsPoolingGrad()->padLeft = pad_left; } -void PoolingGrad::SetPadRight(int pad_right) { this->primitive->value.AsPoolingGrad()->padRight = pad_right; } +void PoolingGrad::SetPadUp(int pad_up) { this->primitive_->value.AsPoolingGrad()->padUp = 
pad_up; } +void PoolingGrad::SetPadDown(int pad_down) { this->primitive_->value.AsPoolingGrad()->padDown = pad_down; } +void PoolingGrad::SetPadLeft(int pad_left) { this->primitive_->value.AsPoolingGrad()->padLeft = pad_left; } +void PoolingGrad::SetPadRight(int pad_right) { this->primitive_->value.AsPoolingGrad()->padRight = pad_right; } void PoolingGrad::SetRoundMode(int round_mode) { - this->primitive->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode; + this->primitive_->value.AsPoolingGrad()->roundMode = (schema::RoundMode)round_mode; } #else -int PoolingGrad::GetFormat() const { return this->primitive->value_as_PoolingGrad()->format(); } -int PoolingGrad::GetPoolingMode() const { return this->primitive->value_as_PoolingGrad()->poolingMode(); } -bool PoolingGrad::GetGlobal() const { return this->primitive->value_as_PoolingGrad()->global(); } -int PoolingGrad::GetWindowW() const { return this->primitive->value_as_PoolingGrad()->windowW(); } -int PoolingGrad::GetWindowH() const { return this->primitive->value_as_PoolingGrad()->windowH(); } -int PoolingGrad::GetStrideW() const { return this->primitive->value_as_PoolingGrad()->strideW(); } -int PoolingGrad::GetStrideH() const { return this->primitive->value_as_PoolingGrad()->strideH(); } -int PoolingGrad::GetPadMode() const { return this->primitive->value_as_PoolingGrad()->padMode(); } -int PoolingGrad::GetPadUp() const { return this->primitive->value_as_PoolingGrad()->padUp(); } -int PoolingGrad::GetPadDown() const { return this->primitive->value_as_PoolingGrad()->padDown(); } -int PoolingGrad::GetPadLeft() const { return this->primitive->value_as_PoolingGrad()->padLeft(); } -int PoolingGrad::GetPadRight() const { return this->primitive->value_as_PoolingGrad()->padRight(); } -int PoolingGrad::GetRoundMode() const { return this->primitive->value_as_PoolingGrad()->roundMode(); } +int PoolingGrad::GetFormat() const { return this->primitive_->value_as_PoolingGrad()->format(); } +int 
PoolingGrad::GetPoolingMode() const { return this->primitive_->value_as_PoolingGrad()->poolingMode(); } +bool PoolingGrad::GetGlobal() const { return this->primitive_->value_as_PoolingGrad()->global(); } +int PoolingGrad::GetWindowW() const { return this->primitive_->value_as_PoolingGrad()->windowW(); } +int PoolingGrad::GetWindowH() const { return this->primitive_->value_as_PoolingGrad()->windowH(); } +int PoolingGrad::GetStrideW() const { return this->primitive_->value_as_PoolingGrad()->strideW(); } +int PoolingGrad::GetStrideH() const { return this->primitive_->value_as_PoolingGrad()->strideH(); } +int PoolingGrad::GetPadMode() const { return this->primitive_->value_as_PoolingGrad()->padMode(); } +int PoolingGrad::GetPadUp() const { return this->primitive_->value_as_PoolingGrad()->padUp(); } +int PoolingGrad::GetPadDown() const { return this->primitive_->value_as_PoolingGrad()->padDown(); } +int PoolingGrad::GetPadLeft() const { return this->primitive_->value_as_PoolingGrad()->padLeft(); } +int PoolingGrad::GetPadRight() const { return this->primitive_->value_as_PoolingGrad()->padRight(); } +int PoolingGrad::GetRoundMode() const { return this->primitive_->value_as_PoolingGrad()->roundMode(); } void PoolingGrad::SetFormat(int format) {} void PoolingGrad::SetPoolingMode(int pooling_mode) {} diff --git a/mindspore/lite/src/ops/pooling_grad.h b/mindspore/lite/src/ops/pooling_grad.h index 617a3a62d5e..4314b22d3a5 100644 --- a/mindspore/lite/src/ops/pooling_grad.h +++ b/mindspore/lite/src/ops/pooling_grad.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POOLING_GRAD_H_ namespace mindspore { namespace lite { class PoolingGrad : public PrimitiveC { public: - explicit PoolingGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit PoolingGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit PoolingGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetFormat() const; int GetPoolingMode() const; diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index 82398aa6836..ff485c0a715 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float Power::GetPower() const { return this->primitive->value.AsPower()->power; } -float Power::GetScale() const { return this->primitive->value.AsPower()->scale; } -float Power::GetShift() const { return this->primitive->value.AsPower()->shift; } +float Power::GetPower() const { return this->primitive_->value.AsPower()->power; } +float Power::GetScale() const { return this->primitive_->value.AsPower()->scale; } +float Power::GetShift() const { return this->primitive_->value.AsPower()->shift; } -void Power::SetPower(float power) { this->primitive->value.AsPower()->power = power; } -void Power::SetScale(float scale) { this->primitive->value.AsPower()->scale = scale; } -void Power::SetShift(float shift) { this->primitive->value.AsPower()->shift = shift; } +void Power::SetPower(float power) { this->primitive_->value.AsPower()->power = power; } +void Power::SetScale(float 
scale) { this->primitive_->value.AsPower()->scale = scale; } +void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = shift; } #else -float Power::GetPower() const { return this->primitive->value_as_Power()->power(); } -float Power::GetScale() const { return this->primitive->value_as_Power()->scale(); } -float Power::GetShift() const { return this->primitive->value_as_Power()->shift(); } +float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } +float Power::GetScale() const { return this->primitive_->value_as_Power()->scale(); } +float Power::GetShift() const { return this->primitive_->value_as_Power()->shift(); } void Power::SetPower(float power) {} void Power::SetScale(float scale) {} @@ -39,7 +39,7 @@ void Power::SetShift(float shift) {} #endif int Power::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto x_tensor = inputs[0]; MS_ASSERT(x_tensor != nullptr); tensor::Tensor *exp_tensor = nullptr; diff --git a/mindspore/lite/src/ops/power.h b/mindspore/lite/src/ops/power.h index cb68962659b..1f252cf2c93 100644 --- a/mindspore/lite/src/ops/power.h +++ b/mindspore/lite/src/ops/power.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_H_ +#define LITE_MINDSPORE_LITE_C_OPS_POWER_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POWER_H_ namespace mindspore { namespace lite { class Power : public PrimitiveC { public: - explicit Power(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Power(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Power(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; float GetPower() const; diff --git a/mindspore/lite/src/ops/power_grad.cc b/mindspore/lite/src/ops/power_grad.cc index b898221ae24..0e9056f4589 100644 --- a/mindspore/lite/src/ops/power_grad.cc +++ b/mindspore/lite/src/ops/power_grad.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -float PowerGrad::GetPower() const { return this->primitive->value.AsPowerGrad()->power; } -float PowerGrad::GetScale() const { return this->primitive->value.AsPowerGrad()->scale; } -float PowerGrad::GetShift() const { return this->primitive->value.AsPowerGrad()->shift; } +float PowerGrad::GetPower() const { return this->primitive_->value.AsPowerGrad()->power; } +float PowerGrad::GetScale() const { return this->primitive_->value.AsPowerGrad()->scale; } +float PowerGrad::GetShift() const { return this->primitive_->value.AsPowerGrad()->shift; } -void PowerGrad::SetPower(float power) { this->primitive->value.AsPowerGrad()->power = power; } -void PowerGrad::SetScale(float scale) { this->primitive->value.AsPowerGrad()->scale = scale; } -void PowerGrad::SetShift(float shift) { this->primitive->value.AsPowerGrad()->shift = shift; } +void PowerGrad::SetPower(float 
power) { this->primitive_->value.AsPowerGrad()->power = power; } +void PowerGrad::SetScale(float scale) { this->primitive_->value.AsPowerGrad()->scale = scale; } +void PowerGrad::SetShift(float shift) { this->primitive_->value.AsPowerGrad()->shift = shift; } #else -float PowerGrad::GetPower() const { return this->primitive->value_as_PowerGrad()->power(); } -float PowerGrad::GetScale() const { return this->primitive->value_as_PowerGrad()->scale(); } -float PowerGrad::GetShift() const { return this->primitive->value_as_PowerGrad()->shift(); } +float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad()->power(); } +float PowerGrad::GetScale() const { return this->primitive_->value_as_PowerGrad()->scale(); } +float PowerGrad::GetShift() const { return this->primitive_->value_as_PowerGrad()->shift(); } void PowerGrad::SetPower(float power) {} void PowerGrad::SetScale(float scale) {} diff --git a/mindspore/lite/src/ops/power_grad.h b/mindspore/lite/src/ops/power_grad.h index bb55b935210..29bd56515d6 100644 --- a/mindspore/lite/src/ops/power_grad.h +++ b/mindspore/lite/src/ops/power_grad.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ +#define LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ -#define LITE_MINDSPORE_LITE_C_OPS_POWER_GRAD_H_ namespace mindspore { namespace lite { class PowerGrad : public PrimitiveC { public: - explicit PowerGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit PowerGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit PowerGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} float GetPower() const; float GetScale() const; diff --git a/mindspore/lite/src/ops/prelu.cc b/mindspore/lite/src/ops/prelu.cc index f0c53e8cd3d..1bca56a3b65 100644 --- a/mindspore/lite/src/ops/prelu.cc +++ b/mindspore/lite/src/ops/prelu.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Prelu::GetSlope() const { return this->primitive->value.AsPrelu()->slope; } +std::vector Prelu::GetSlope() const { return this->primitive_->value.AsPrelu()->slope; } -void Prelu::SetSlope(const std::vector &slope) { this->primitive->value.AsPrelu()->slope = slope; } +void Prelu::SetSlope(const std::vector &slope) { this->primitive_->value.AsPrelu()->slope = slope; } #else std::vector Prelu::GetSlope() const { - auto fb_vector = this->primitive->value_as_Prelu()->slope(); + auto fb_vector = this->primitive_->value_as_Prelu()->slope(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/prelu.h b/mindspore/lite/src/ops/prelu.h index 87bd188b21d..f57fd9f5c91 100644 --- a/mindspore/lite/src/ops/prelu.h +++ b/mindspore/lite/src/ops/prelu.h @@ -14,27 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_PRELU_H_ +#define LITE_MINDSPORE_LITE_C_OPS_PRELU_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" #include "src/ops/activation.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_PRELU_H_ -#define LITE_MINDSPORE_LITE_C_OPS_PRELU_H_ - namespace mindspore { namespace lite { class Prelu : public Activation { public: - explicit Prelu(OriginPrimitive *primitive) : Activation(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Prelu(schema::PrimitiveT *primitive) : Activation(primitive) {} +#endif + explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {} std::vector GetSlope() const; void SetSlope(const std::vector &slope); diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 826e89ccc06..b98c09e1bd1 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -15,6 +15,7 @@ */ #include "src/ops/primitive_c.h" +#include #include "src/ops/space_to_batch.h" #include "src/ops/conv2d.h" #include "src/ops/roi_pooling.h" @@ -114,29 +115,60 @@ namespace mindspore { namespace lite { -int PrimitiveC::InferShape(std::vector inputs_, std::vector outputs_) { - auto input = inputs_.front(); - MS_ASSERT(input != nullptr); - auto output = outputs_.front(); - MS_ASSERT(output != nullptr); - output->set_shape(input->shape()); - output->set_data_type(input->data_type()); - output->SetFormat(input->GetFormat()); - return 0; -} -int PrimitiveC::Type() const { #ifdef PRIMITIVE_WRITEABLE - return this->primitive->value.type; -#else - return this->primitive->value_type(); -#endif +schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } + +void PrimitiveC::SetPrimitiveT(schema::PrimitiveT *prim) { this->primitive_ = prim; } + +void PrimitiveC::SetInputQuantParam(const 
std::vector> &input_quant_param) { + this->input_quant_param_ = input_quant_param; } -bool PrimitiveC::GetInferFlag() const { return this->infer_flag_; } +void PrimitiveC::SetOutputQuantParam(const std::vector> &output_quant_param) { + this->output_quant_param_ = output_quant_param; +} -void PrimitiveC::SetInferFlag(bool flag) { this->infer_flag_ = flag; } +void PrimitiveC::ClearInputOutputQuantParam() { + input_quant_param_.clear(); + output_quant_param_.clear(); +} -PrimitiveC *PrimitiveC::CreatePrimitive(OriginPrimitive *primitive) { +void PrimitiveC::AddInputQuantParam(std::vector quant_param) { + this->input_quant_param_.emplace_back(quant_param); +} +std::vector> PrimitiveC::GetInputQuantParams() const { return input_quant_param_; } + +void PrimitiveC::AddOutputQuantParam(std::vector quant_param) { + this->output_quant_param_.emplace_back(quant_param); +} +std::vector> PrimitiveC::GetOutputQuantParams() const { return output_quant_param_; } + +void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } + +schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; } + +std::shared_ptr GetReturnPrim() { + auto return_primitiveT = new schema::PrimitiveT; + return_primitiveT->value.type = schema::PrimitiveType_Return; + return_primitiveT->value.value = new schema::ReturnT; + return std::make_shared(return_primitiveT); +} + +std::shared_ptr GetMakeTuplePrim() { + auto make_tuple_primitiveT = new schema::PrimitiveT; + make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; + make_tuple_primitiveT->value.value = new schema::MakeTupleT; + return std::make_shared(make_tuple_primitiveT); +} + +std::shared_ptr GetTupleGetItemPrim() { + auto tuple_get_item_primitiveT = new schema::PrimitiveT(); + tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; + tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT; + return std::make_shared(tuple_get_item_primitiveT); +} + +PrimitiveC 
*PrimitiveC::CreatePrimitive(mindspore::schema::Primitive *primitive) { MS_ASSERT(primitive != nullptr); auto op_type = primitive->value_type(); switch (op_type) { @@ -267,5 +299,30 @@ PrimitiveC *PrimitiveC::CreatePrimitive(OriginPrimitive *primitive) { } return nullptr; } + +#endif + +int PrimitiveC::Type() const { +#ifdef PRIMITIVE_WRITEABLE + return this->primitive_->value.type; +#else + return this->primitive_->value_type(); +#endif +} +bool PrimitiveC::GetInferFlag() const { return this->infer_flag_; } + +void PrimitiveC::SetInferFlag(bool flag) { this->infer_flag_ = flag; } + +int PrimitiveC::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_shape(input->shape()); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + return 0; +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 719e0222e9f..80474b5d57b 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -19,18 +19,18 @@ #include #include #include +#include +#ifdef PRIMITIVE_WRITEABLE +#include "ir/primitive.h" +#include "schema/inner/model_generated.h" +#else +#include "schema/model_generated.h" +#endif + #include "src/ir/tensor.h" #include "include/errorcode.h" #include "utils/log_adapter.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -using OriginPrimitive = mindspore::schema::PrimitiveT; -#else -#include "schema/model_generated.h" -using OriginPrimitive = mindspore::schema::Primitive; -#endif - namespace mindspore { namespace lite { constexpr uint32_t kSingleNum = 1; @@ -40,16 +40,93 @@ constexpr uint32_t kDimension_4d = 4; const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; -// #if LITE_OPTIMIZE +#ifdef 
PRIMITIVE_WRITEABLE +class PrimitiveC : public mindspore::Primitive { + public: + explicit PrimitiveC(schema::Primitive *primitive) : Primitive("") { this->primitive_ = primitive->UnPack(); } + + explicit PrimitiveC(schema::PrimitiveT *primitive) : Primitive(""), primitive_(primitive) {} + + explicit PrimitiveC(const Primitive &prim) : Primitive(prim) {} + + explicit PrimitiveC(const std::string &name, schema::PrimitiveT *primitive) + : Primitive(name), primitive_(primitive) {} + + MS_DECLARE_PARENT(PrimitiveC, Primitive); + + ~PrimitiveC() override = default; + + int Type() const; + + // static PrimitiveC *UnPackFromPrimitive(const Primitive &prim); + + schema::PrimitiveT *GetPrimitiveT() const; + + void SetPrimitiveT(schema::PrimitiveT *prim); + + bool operator==(const Value &rhs) const { + if (rhs.isa()) { + auto other_prim = static_cast(rhs); + auto a = this->primitive_->value.type; + auto b = other_prim.primitive_->value.type; + return a == b; + } else { + return false; + } + } + + void SetInputQuantParam(const std::vector> &input_quant_param); + + void SetOutputQuantParam(const std::vector> &output_quant_param); + + void ClearInputOutputQuantParam(); + + void AddInputQuantParam(std::vector quant_param); + + std::vector> GetInputQuantParams() const; + + void AddOutputQuantParam(std::vector quant_param); + + std::vector> GetOutputQuantParams() const; + + void SetQuantType(schema::QuantType quant_type); + + schema::QuantType GetQuantType() const; + + virtual int InferShape(std::vector inputs_, std::vector outputs_); + + bool GetInferFlag() const; + + void SetInferFlag(bool flag); + + static PrimitiveC *CreatePrimitive(mindspore::schema::Primitive *primitive); + + protected: + // virutal PrimitiveC *UnPackAttr(const Primitive &prim) = 0; + + protected: + schema::PrimitiveT *primitive_ = nullptr; + std::vector> input_quant_param_; + std::vector> output_quant_param_; + schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; + bool infer_flag_ = true; +}; 
+std::shared_ptr GetReturnPrim(); + +std::shared_ptr GetMakeTuplePrim(); + +std::shared_ptr GetTupleGetItemPrim(); + + + +#else class PrimitiveC { public: PrimitiveC() = default; - explicit PrimitiveC(OriginPrimitive *primitive) : primitive(primitive) {} + explicit PrimitiveC(schema::Primitive *primitive) : primitive_(primitive) {} - static PrimitiveC *CreatePrimitive(OriginPrimitive *primitive); - - virtual ~PrimitiveC() {} + virtual ~PrimitiveC() = default; bool GetInferFlag() const; @@ -60,9 +137,10 @@ class PrimitiveC { int Type() const; protected: - OriginPrimitive *primitive; + schema::Primitive *primitive_ = nullptr; bool infer_flag_ = true; }; +#endif } // namespace lite } // namespace mindspore #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc index ec7cb3d4717..3bad25e0d83 100644 --- a/mindspore/lite/src/ops/prior_box.cc +++ b/mindspore/lite/src/ops/prior_box.cc @@ -19,63 +19,63 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector PriorBox::GetMinSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; } -std::vector PriorBox::GetMaxSizes() const { return this->primitive->value.AsPriorBox()->max_sizes; } -std::vector PriorBox::GetAspectRatios() const { return this->primitive->value.AsPriorBox()->aspect_ratios; } -std::vector PriorBox::GetVariances() const { return this->primitive->value.AsPriorBox()->variances; } -int PriorBox::GetImageSizeW() const { return this->primitive->value.AsPriorBox()->image_size_w; } -int PriorBox::GetImageSizeH() const { return this->primitive->value.AsPriorBox()->image_size_h; } -float PriorBox::GetStepW() const { return this->primitive->value.AsPriorBox()->step_w; } -float PriorBox::GetStepH() const { return this->primitive->value.AsPriorBox()->step_h; } -bool PriorBox::GetClip() const { return this->primitive->value.AsPriorBox()->clip; } -bool PriorBox::GetFlip() const { return 
this->primitive->value.AsPriorBox()->flip; } -float PriorBox::GetOffset() const { return this->primitive->value.AsPriorBox()->offset; } +std::vector PriorBox::GetMinSizes() const { return this->primitive_->value.AsPriorBox()->max_sizes; } +std::vector PriorBox::GetMaxSizes() const { return this->primitive_->value.AsPriorBox()->max_sizes; } +std::vector PriorBox::GetAspectRatios() const { return this->primitive_->value.AsPriorBox()->aspect_ratios; } +std::vector PriorBox::GetVariances() const { return this->primitive_->value.AsPriorBox()->variances; } +int PriorBox::GetImageSizeW() const { return this->primitive_->value.AsPriorBox()->image_size_w; } +int PriorBox::GetImageSizeH() const { return this->primitive_->value.AsPriorBox()->image_size_h; } +float PriorBox::GetStepW() const { return this->primitive_->value.AsPriorBox()->step_w; } +float PriorBox::GetStepH() const { return this->primitive_->value.AsPriorBox()->step_h; } +bool PriorBox::GetClip() const { return this->primitive_->value.AsPriorBox()->clip; } +bool PriorBox::GetFlip() const { return this->primitive_->value.AsPriorBox()->flip; } +float PriorBox::GetOffset() const { return this->primitive_->value.AsPriorBox()->offset; } void PriorBox::SetMinSizes(const std::vector &min_sizes) { - this->primitive->value.AsPriorBox()->min_sizes = min_sizes; + this->primitive_->value.AsPriorBox()->min_sizes = min_sizes; } void PriorBox::SetMaxSizes(const std::vector &max_sizes) { - this->primitive->value.AsPriorBox()->max_sizes = max_sizes; + this->primitive_->value.AsPriorBox()->max_sizes = max_sizes; } void PriorBox::SetAspectRatios(const std::vector &aspect_ratios) { - this->primitive->value.AsPriorBox()->aspect_ratios = aspect_ratios; + this->primitive_->value.AsPriorBox()->aspect_ratios = aspect_ratios; } void PriorBox::SetVariances(const std::vector &variances) { - this->primitive->value.AsPriorBox()->variances = variances; + this->primitive_->value.AsPriorBox()->variances = variances; } -void 
PriorBox::SetImageSizeW(int image_size_w) { this->primitive->value.AsPriorBox()->image_size_w = image_size_w; } -void PriorBox::SetImageSizeH(int image_size_h) { this->primitive->value.AsPriorBox()->image_size_h = image_size_h; } -void PriorBox::SetStepW(float step_w) { this->primitive->value.AsPriorBox()->step_w = step_w; } -void PriorBox::SetStepH(float step_h) { this->primitive->value.AsPriorBox()->step_h = step_h; } -void PriorBox::SetClip(bool clip) { this->primitive->value.AsPriorBox()->clip = clip; } -void PriorBox::SetFlip(bool flip) { this->primitive->value.AsPriorBox()->flip = flip; } -void PriorBox::SetOffset(float offset) { this->primitive->value.AsPriorBox()->offset = offset; } +void PriorBox::SetImageSizeW(int image_size_w) { this->primitive_->value.AsPriorBox()->image_size_w = image_size_w; } +void PriorBox::SetImageSizeH(int image_size_h) { this->primitive_->value.AsPriorBox()->image_size_h = image_size_h; } +void PriorBox::SetStepW(float step_w) { this->primitive_->value.AsPriorBox()->step_w = step_w; } +void PriorBox::SetStepH(float step_h) { this->primitive_->value.AsPriorBox()->step_h = step_h; } +void PriorBox::SetClip(bool clip) { this->primitive_->value.AsPriorBox()->clip = clip; } +void PriorBox::SetFlip(bool flip) { this->primitive_->value.AsPriorBox()->flip = flip; } +void PriorBox::SetOffset(float offset) { this->primitive_->value.AsPriorBox()->offset = offset; } #else std::vector PriorBox::GetMinSizes() const { - auto fb_vector = this->primitive->value_as_PriorBox()->min_sizes(); + auto fb_vector = this->primitive_->value_as_PriorBox()->min_sizes(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector PriorBox::GetMaxSizes() const { - auto fb_vector = this->primitive->value_as_PriorBox()->max_sizes(); + auto fb_vector = this->primitive_->value_as_PriorBox()->max_sizes(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector PriorBox::GetAspectRatios() const { - auto fb_vector = 
this->primitive->value_as_PriorBox()->aspect_ratios(); + auto fb_vector = this->primitive_->value_as_PriorBox()->aspect_ratios(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector PriorBox::GetVariances() const { - auto fb_vector = this->primitive->value_as_PriorBox()->variances(); + auto fb_vector = this->primitive_->value_as_PriorBox()->variances(); return std::vector(fb_vector->begin(), fb_vector->end()); } -int PriorBox::GetImageSizeW() const { return this->primitive->value_as_PriorBox()->image_size_w(); } -int PriorBox::GetImageSizeH() const { return this->primitive->value_as_PriorBox()->image_size_h(); } -float PriorBox::GetStepW() const { return this->primitive->value_as_PriorBox()->step_w(); } -float PriorBox::GetStepH() const { return this->primitive->value_as_PriorBox()->step_h(); } -bool PriorBox::GetClip() const { return this->primitive->value_as_PriorBox()->clip(); } -bool PriorBox::GetFlip() const { return this->primitive->value_as_PriorBox()->flip(); } -float PriorBox::GetOffset() const { return this->primitive->value_as_PriorBox()->offset(); } +int PriorBox::GetImageSizeW() const { return this->primitive_->value_as_PriorBox()->image_size_w(); } +int PriorBox::GetImageSizeH() const { return this->primitive_->value_as_PriorBox()->image_size_h(); } +float PriorBox::GetStepW() const { return this->primitive_->value_as_PriorBox()->step_w(); } +float PriorBox::GetStepH() const { return this->primitive_->value_as_PriorBox()->step_h(); } +bool PriorBox::GetClip() const { return this->primitive_->value_as_PriorBox()->clip(); } +bool PriorBox::GetFlip() const { return this->primitive_->value_as_PriorBox()->flip(); } +float PriorBox::GetOffset() const { return this->primitive_->value_as_PriorBox()->offset(); } void PriorBox::SetMinSizes(const std::vector &min_sizes) {} void PriorBox::SetMaxSizes(const std::vector &max_sizes) {} diff --git a/mindspore/lite/src/ops/prior_box.h b/mindspore/lite/src/ops/prior_box.h index 
78eaafc8496..508cd88659a 100644 --- a/mindspore/lite/src/ops/prior_box.h +++ b/mindspore/lite/src/ops/prior_box.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_PRIOR_BOX_H_ +#define LITE_MINDSPORE_LITE_C_OPS_PRIOR_BOX_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_PRIOR_BOX_H_ -#define LITE_MINDSPORE_LITE_C_OPS_PRIOR_BOX_H_ namespace mindspore { namespace lite { class PriorBox : public PrimitiveC { public: - explicit PriorBox(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit PriorBox(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit PriorBox(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMinSizes() const; diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index f3852021041..c29ee9a95ff 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -19,23 +19,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int QuantDTypeCast::GetSrcT() const { return this->primitive->value.AsQuantDTypeCast()->srcT; } -int QuantDTypeCast::GetDstT() const { return this->primitive->value.AsQuantDTypeCast()->dstT; } +int QuantDTypeCast::GetSrcT() const { return this->primitive_->value.AsQuantDTypeCast()->srcT; } +int QuantDTypeCast::GetDstT() const { return this->primitive_->value.AsQuantDTypeCast()->dstT; } -void QuantDTypeCast::SetSrcT(int src_t) { this->primitive->value.AsQuantDTypeCast()->srcT = src_t; } -void QuantDTypeCast::SetDstT(int dst_t) { this->primitive->value.AsQuantDTypeCast()->dstT = dst_t; } +void QuantDTypeCast::SetSrcT(int src_t) { 
this->primitive_->value.AsQuantDTypeCast()->srcT = src_t; } +void QuantDTypeCast::SetDstT(int dst_t) { this->primitive_->value.AsQuantDTypeCast()->dstT = dst_t; } #else -int QuantDTypeCast::GetSrcT() const { return this->primitive->value_as_QuantDTypeCast()->srcT(); } -int QuantDTypeCast::GetDstT() const { return this->primitive->value_as_QuantDTypeCast()->dstT(); } +int QuantDTypeCast::GetSrcT() const { return this->primitive_->value_as_QuantDTypeCast()->srcT(); } +int QuantDTypeCast::GetDstT() const { return this->primitive_->value_as_QuantDTypeCast()->dstT(); } void QuantDTypeCast::SetSrcT(int src_t) {} void QuantDTypeCast::SetDstT(int dst_t) {} #endif int QuantDTypeCast::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/quant_dtype_cast.h b/mindspore/lite/src/ops/quant_dtype_cast.h index 2970ac5e379..dbbb689d2fb 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.h +++ b/mindspore/lite/src/ops/quant_dtype_cast.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_QUANT_D_TYPE_CAST_H_ +#define LITE_MINDSPORE_LITE_C_OPS_QUANT_D_TYPE_CAST_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_QUANT_D_TYPE_CAST_H_ -#define LITE_MINDSPORE_LITE_C_OPS_QUANT_D_TYPE_CAST_H_ namespace mindspore { namespace lite { class QuantDTypeCast : public PrimitiveC { public: - explicit QuantDTypeCast(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit QuantDTypeCast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit QuantDTypeCast(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index b246af74bc4..9c5816b3376 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -19,22 +19,22 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Range::GetDType() const { return this->primitive->value.AsRange()->dType; } -int Range::GetStart() const { return this->primitive->value.AsRange()->start; } -int Range::GetLimit() const { return this->primitive->value.AsRange()->limit; } -int Range::GetDelta() const { return this->primitive->value.AsRange()->delta; } +int Range::GetDType() const { return this->primitive_->value.AsRange()->dType; } +int Range::GetStart() const { return this->primitive_->value.AsRange()->start; } +int Range::GetLimit() const { return this->primitive_->value.AsRange()->limit; } +int Range::GetDelta() const { return this->primitive_->value.AsRange()->delta; } -void Range::SetDType(int d_type) { this->primitive->value.AsRange()->dType = d_type; } -void Range::SetStart(int start) { 
this->primitive->value.AsRange()->start = start; } -void Range::SetLimit(int limit) { this->primitive->value.AsRange()->limit = limit; } -void Range::SetDelta(int delta) { this->primitive->value.AsRange()->delta = delta; } +void Range::SetDType(int d_type) { this->primitive_->value.AsRange()->dType = d_type; } +void Range::SetStart(int start) { this->primitive_->value.AsRange()->start = start; } +void Range::SetLimit(int limit) { this->primitive_->value.AsRange()->limit = limit; } +void Range::SetDelta(int delta) { this->primitive_->value.AsRange()->delta = delta; } #else -int Range::GetDType() const { return this->primitive->value_as_Range()->dType(); } -int Range::GetStart() const { return this->primitive->value_as_Range()->start(); } -int Range::GetLimit() const { return this->primitive->value_as_Range()->limit(); } -int Range::GetDelta() const { return this->primitive->value_as_Range()->delta(); } +int Range::GetDType() const { return this->primitive_->value_as_Range()->dType(); } +int Range::GetStart() const { return this->primitive_->value_as_Range()->start(); } +int Range::GetLimit() const { return this->primitive_->value_as_Range()->limit(); } +int Range::GetDelta() const { return this->primitive_->value_as_Range()->delta(); } void Range::SetDType(int d_type) {} void Range::SetStart(int start) {} @@ -43,7 +43,7 @@ void Range::SetDelta(int delta) {} #endif int Range::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/range.h b/mindspore/lite/src/ops/range.h index 237c764b2fc..b7abbc102d0 100644 --- a/mindspore/lite/src/ops/range.h +++ b/mindspore/lite/src/ops/range.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_RANGE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RANGE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_RANGE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RANGE_H_ namespace mindspore { namespace lite { class Range : public PrimitiveC { public: - explicit Range(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Range(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Range(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDType() const; diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index 8057f52b9e9..5a89c681785 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace lite { int Rank::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/rank.h b/mindspore/lite/src/ops/rank.h index 3c979f9d543..b56f4560327 100644 --- a/mindspore/lite/src/ops/rank.h +++ b/mindspore/lite/src/ops/rank.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_RANK_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RANK_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_RANK_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RANK_H_ namespace mindspore { namespace lite { class Rank : public PrimitiveC { public: - explicit Rank(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Rank(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Rank(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 9e887c9b6a4..2bd709a5007 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -19,22 +19,22 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Reduce::GetAxes() const { return this->primitive->value.AsReduce()->axes; } -int Reduce::GetKeepDims() const { return this->primitive->value.AsReduce()->keepDims; } -int Reduce::GetMode() const { return this->primitive->value.AsReduce()->mode; } +std::vector Reduce::GetAxes() const { return this->primitive_->value.AsReduce()->axes; } +int Reduce::GetKeepDims() const { return this->primitive_->value.AsReduce()->keepDims; } +int Reduce::GetMode() const { return this->primitive_->value.AsReduce()->mode; } -void Reduce::SetAxes(const std::vector &axes) { this->primitive->value.AsReduce()->axes = axes; } -void Reduce::SetKeepDims(int keep_dims) { this->primitive->value.AsReduce()->keepDims = keep_dims; } -void Reduce::SetMode(int mode) { this->primitive->value.AsReduce()->mode = (schema::ReduceMode)mode; } +void Reduce::SetAxes(const std::vector &axes) { this->primitive_->value.AsReduce()->axes = 
axes; } +void Reduce::SetKeepDims(int keep_dims) { this->primitive_->value.AsReduce()->keepDims = keep_dims; } +void Reduce::SetMode(int mode) { this->primitive_->value.AsReduce()->mode = (schema::ReduceMode)mode; } #else std::vector Reduce::GetAxes() const { - auto fb_vector = this->primitive->value_as_Reduce()->axes(); + auto fb_vector = this->primitive_->value_as_Reduce()->axes(); return std::vector(fb_vector->begin(), fb_vector->end()); } -int Reduce::GetKeepDims() const { return this->primitive->value_as_Reduce()->keepDims(); } -int Reduce::GetMode() const { return this->primitive->value_as_Reduce()->mode(); } +int Reduce::GetKeepDims() const { return this->primitive_->value_as_Reduce()->keepDims(); } +int Reduce::GetMode() const { return this->primitive_->value_as_Reduce()->mode(); } void Reduce::SetAxes(const std::vector &axes) {} void Reduce::SetKeepDims(int keep_dims) {} @@ -59,7 +59,7 @@ int Reduce::InferShape(std::vector inputs_, std::vectorprimitive == nullptr) { + if (this->primitive_ == nullptr) { return RET_NULL_PTR; } diff --git a/mindspore/lite/src/ops/reduce.h b/mindspore/lite/src/ops/reduce.h index 5cce3a8fb67..079834742c1 100644 --- a/mindspore/lite/src/ops/reduce.h +++ b/mindspore/lite/src/ops/reduce.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_REDUCE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_REDUCE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_REDUCE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_REDUCE_H_ namespace mindspore { namespace lite { class Reduce : public PrimitiveC { public: - explicit Reduce(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Reduce(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Reduce(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxes() const; diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index e683beb2358..1694d3154da 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -23,17 +23,17 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Reshape::GetFormat() const { return this->primitive->value.AsReshape()->format; } -std::vector Reshape::GetShape() const { return this->primitive->value.AsReshape()->shape; } +int Reshape::GetFormat() const { return this->primitive_->value.AsReshape()->format; } +std::vector Reshape::GetShape() const { return this->primitive_->value.AsReshape()->shape; } -void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = (schema::Format) format; } -void Reshape::SetShape(const std::vector &shape) { this->primitive->value.AsReshape()->shape = shape; } +void Reshape::SetFormat(int format) { this->primitive_->value.AsReshape()->format = (schema::Format)format; } +void Reshape::SetShape(const std::vector &shape) { this->primitive_->value.AsReshape()->shape = shape; } #else -int Reshape::GetFormat() const { return 
this->primitive->value_as_Reshape()->format(); } +int Reshape::GetFormat() const { return this->primitive_->value_as_Reshape()->format(); } std::vector Reshape::GetShape() const { - auto fb_vector = this->primitive->value_as_Reshape()->shape(); + auto fb_vector = this->primitive_->value_as_Reshape()->shape(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -75,7 +75,7 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_ } return RET_OK; } -template +template void CalShape(const T *data, const std::vector &inputs, std::vector *out_shape, int shape_size) { int input_count = inputs[0]->ElementsNum(); int index = 0; @@ -93,7 +93,7 @@ void CalShape(const T *data, const std::vector &inputs, std::v } } int Reshape::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); @@ -117,28 +117,23 @@ int Reshape::InferShape(std::vector inputs_, std::vector(shape_tensor->Data()); CalShape(data, inputs_, &out_shape, shape_size); - } - break; + } break; case kNumberTypeInt32: { auto data = reinterpret_cast(shape_tensor->Data()); CalShape(data, inputs_, &out_shape, shape_size); - } - break; + } break; case kNumberTypeInt64: { auto data = reinterpret_cast(shape_tensor->Data()); CalShape(data, inputs_, &out_shape, shape_size); - } - break; + } break; case kNumberTypeFloat: { auto data = reinterpret_cast(shape_tensor->Data()); CalShape(data, inputs_, &out_shape, shape_size); - } - break; + } break; case kNumberTypeUInt32: { auto data = reinterpret_cast(shape_tensor->Data()); CalShape(data, inputs_, &out_shape, shape_size); - } - break; + } break; default: { MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); return RET_INFER_ERR; diff --git a/mindspore/lite/src/ops/reshape.h b/mindspore/lite/src/ops/reshape.h index 
c7f77974f69..0343d20bd39 100644 --- a/mindspore/lite/src/ops/reshape.h +++ b/mindspore/lite/src/ops/reshape.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RESHAPE_H_ namespace mindspore { namespace lite { class Reshape : public PrimitiveC { public: - explicit Reshape(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Reshape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Reshape(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index 1ef6c3c2dea..26f9002baec 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -19,30 +19,30 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Resize::GetFormat() const { return this->primitive->value.AsResize()->format; } -int Resize::GetMethod() const { return this->primitive->value.AsResize()->method; } -long Resize::GetNewHeight() const { return this->primitive->value.AsResize()->newHeight; } -long Resize::GetNewWidth() const { return this->primitive->value.AsResize()->newWidth; } -bool Resize::GetAlignCorners() const { return this->primitive->value.AsResize()->alignCorners; } -bool Resize::GetPreserveAspectRatio() const { return this->primitive->value.AsResize()->preserveAspectRatio; } +int Resize::GetFormat() const { return this->primitive_->value.AsResize()->format; } +int Resize::GetMethod() const { return 
this->primitive_->value.AsResize()->method; } +long Resize::GetNewHeight() const { return this->primitive_->value.AsResize()->newHeight; } +long Resize::GetNewWidth() const { return this->primitive_->value.AsResize()->newWidth; } +bool Resize::GetAlignCorners() const { return this->primitive_->value.AsResize()->alignCorners; } +bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value.AsResize()->preserveAspectRatio; } -void Resize::SetFormat(int format) { this->primitive->value.AsResize()->format = (schema::Format)format; } -void Resize::SetMethod(int method) { this->primitive->value.AsResize()->method = (schema::ResizeMethod)method; } -void Resize::SetNewHeight(long new_height) { this->primitive->value.AsResize()->newHeight = new_height; } -void Resize::SetNewWidth(long new_width) { this->primitive->value.AsResize()->newWidth = new_width; } -void Resize::SetAlignCorners(bool align_corners) { this->primitive->value.AsResize()->alignCorners = align_corners; } +void Resize::SetFormat(int format) { this->primitive_->value.AsResize()->format = (schema::Format)format; } +void Resize::SetMethod(int method) { this->primitive_->value.AsResize()->method = (schema::ResizeMethod)method; } +void Resize::SetNewHeight(long new_height) { this->primitive_->value.AsResize()->newHeight = new_height; } +void Resize::SetNewWidth(long new_width) { this->primitive_->value.AsResize()->newWidth = new_width; } +void Resize::SetAlignCorners(bool align_corners) { this->primitive_->value.AsResize()->alignCorners = align_corners; } void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) { - this->primitive->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; + this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio; } #else -int Resize::GetFormat() const { return this->primitive->value_as_Resize()->format(); } -int Resize::GetMethod() const { return this->primitive->value_as_Resize()->method(); } -long Resize::GetNewHeight() const 
{ return this->primitive->value_as_Resize()->newHeight(); } -long Resize::GetNewWidth() const { return this->primitive->value_as_Resize()->newWidth(); } -bool Resize::GetAlignCorners() const { return this->primitive->value_as_Resize()->alignCorners(); } -bool Resize::GetPreserveAspectRatio() const { return this->primitive->value_as_Resize()->preserveAspectRatio(); } +int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); } +int Resize::GetMethod() const { return this->primitive_->value_as_Resize()->method(); } +long Resize::GetNewHeight() const { return this->primitive_->value_as_Resize()->newHeight(); } +long Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); } +bool Resize::GetAlignCorners() const { return this->primitive_->value_as_Resize()->alignCorners(); } +bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value_as_Resize()->preserveAspectRatio(); } void Resize::SetFormat(int format) {} void Resize::SetMethod(int method) {} @@ -55,7 +55,7 @@ namespace { constexpr int kInputRank = 4; } // namespace int Resize::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); if (input == nullptr) { return 1; diff --git a/mindspore/lite/src/ops/resize.h b/mindspore/lite/src/ops/resize.h index 0f365a7c06f..697aa360081 100644 --- a/mindspore/lite/src/ops/resize.h +++ b/mindspore/lite/src/ops/resize.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_RESIZE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RESIZE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_RESIZE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RESIZE_H_ namespace mindspore { namespace lite { class Resize : public PrimitiveC { public: - explicit Resize(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Resize(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Resize(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/reverse.cc b/mindspore/lite/src/ops/reverse.cc index ad941dc705e..b4a56286c4b 100644 --- a/mindspore/lite/src/ops/reverse.cc +++ b/mindspore/lite/src/ops/reverse.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Reverse::GetAxis() const { return this->primitive->value.AsReverse()->axis; } +std::vector Reverse::GetAxis() const { return this->primitive_->value.AsReverse()->axis; } -void Reverse::SetAxis(const std::vector &axis) { this->primitive->value.AsReverse()->axis = axis; } +void Reverse::SetAxis(const std::vector &axis) { this->primitive_->value.AsReverse()->axis = axis; } #else std::vector Reverse::GetAxis() const { - auto fb_vector = this->primitive->value_as_Reverse()->axis(); + auto fb_vector = this->primitive_->value_as_Reverse()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/reverse.h b/mindspore/lite/src/ops/reverse.h index 44c45c68435..9ac0ec7ce40 100644 --- a/mindspore/lite/src/ops/reverse.h +++ b/mindspore/lite/src/ops/reverse.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_REVERSE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_REVERSE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_REVERSE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_REVERSE_H_ namespace mindspore { namespace lite { class Reverse : public PrimitiveC { public: - explicit Reverse(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Reverse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Reverse(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc index fc7179d09ab..e362b98c249 100644 --- a/mindspore/lite/src/ops/reverse_sequence.cc +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -19,26 +19,26 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ReverseSequence::GetSeqAxis() const { return this->primitive->value.AsReverseSequence()->seqAxis; } -int ReverseSequence::GetBatchAxis() const { return this->primitive->value.AsReverseSequence()->batchAxis; } +int ReverseSequence::GetSeqAxis() const { return this->primitive_->value.AsReverseSequence()->seqAxis; } +int ReverseSequence::GetBatchAxis() const { return this->primitive_->value.AsReverseSequence()->batchAxis; } std::vector ReverseSequence::GetSeqLengths() const { - return this->primitive->value.AsReverseSequence()->seqLengths; + return this->primitive_->value.AsReverseSequence()->seqLengths; } -void ReverseSequence::SetSeqAxis(int seq_axis) { this->primitive->value.AsReverseSequence()->seqAxis = seq_axis; } +void ReverseSequence::SetSeqAxis(int seq_axis) { this->primitive_->value.AsReverseSequence()->seqAxis = seq_axis; } void 
ReverseSequence::SetBatchAxis(int batch_axis) { - this->primitive->value.AsReverseSequence()->batchAxis = batch_axis; + this->primitive_->value.AsReverseSequence()->batchAxis = batch_axis; } void ReverseSequence::SetSeqLengths(const std::vector &seq_lengths) { - this->primitive->value.AsReverseSequence()->seqLengths = seq_lengths; + this->primitive_->value.AsReverseSequence()->seqLengths = seq_lengths; } #else -int ReverseSequence::GetSeqAxis() const { return this->primitive->value_as_ReverseSequence()->seqAxis(); } -int ReverseSequence::GetBatchAxis() const { return this->primitive->value_as_ReverseSequence()->batchAxis(); } +int ReverseSequence::GetSeqAxis() const { return this->primitive_->value_as_ReverseSequence()->seqAxis(); } +int ReverseSequence::GetBatchAxis() const { return this->primitive_->value_as_ReverseSequence()->batchAxis(); } std::vector ReverseSequence::GetSeqLengths() const { - auto fb_vector = this->primitive->value_as_ReverseSequence()->seqLengths(); + auto fb_vector = this->primitive_->value_as_ReverseSequence()->seqLengths(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/reverse_sequence.h b/mindspore/lite/src/ops/reverse_sequence.h index 95c798da020..67e543a9d0f 100644 --- a/mindspore/lite/src/ops/reverse_sequence.h +++ b/mindspore/lite/src/ops/reverse_sequence.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_REVERSE_SEQUENCE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_REVERSE_SEQUENCE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_REVERSE_SEQUENCE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_REVERSE_SEQUENCE_H_ namespace mindspore { namespace lite { class ReverseSequence : public PrimitiveC { public: - explicit ReverseSequence(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ReverseSequence(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ReverseSequence(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSeqAxis() const; diff --git a/mindspore/lite/src/ops/roi_pooling.cc b/mindspore/lite/src/ops/roi_pooling.cc index afd2dde720a..03edeb26393 100644 --- a/mindspore/lite/src/ops/roi_pooling.cc +++ b/mindspore/lite/src/ops/roi_pooling.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int ROIPooling::GetPooledH() const { return this->primitive->value.AsROIPooling()->pooledH; } -int ROIPooling::GetPooledW() const { return this->primitive->value.AsROIPooling()->pooledW; } -float ROIPooling::GetScale() const { return this->primitive->value.AsROIPooling()->scale; } +int ROIPooling::GetPooledH() const { return this->primitive_->value.AsROIPooling()->pooledH; } +int ROIPooling::GetPooledW() const { return this->primitive_->value.AsROIPooling()->pooledW; } +float ROIPooling::GetScale() const { return this->primitive_->value.AsROIPooling()->scale; } -void ROIPooling::SetPooledH(int pooled_h) { this->primitive->value.AsROIPooling()->pooledH = pooled_h; } -void ROIPooling::SetPooledW(int pooled_w) { this->primitive->value.AsROIPooling()->pooledW = pooled_w; } -void 
ROIPooling::SetScale(float scale) { this->primitive->value.AsROIPooling()->scale = scale; } +void ROIPooling::SetPooledH(int pooled_h) { this->primitive_->value.AsROIPooling()->pooledH = pooled_h; } +void ROIPooling::SetPooledW(int pooled_w) { this->primitive_->value.AsROIPooling()->pooledW = pooled_w; } +void ROIPooling::SetScale(float scale) { this->primitive_->value.AsROIPooling()->scale = scale; } #else -int ROIPooling::GetPooledH() const { return this->primitive->value_as_ROIPooling()->pooledH(); } -int ROIPooling::GetPooledW() const { return this->primitive->value_as_ROIPooling()->pooledW(); } -float ROIPooling::GetScale() const { return this->primitive->value_as_ROIPooling()->scale(); } +int ROIPooling::GetPooledH() const { return this->primitive_->value_as_ROIPooling()->pooledH(); } +int ROIPooling::GetPooledW() const { return this->primitive_->value_as_ROIPooling()->pooledW(); } +float ROIPooling::GetScale() const { return this->primitive_->value_as_ROIPooling()->scale(); } void ROIPooling::SetPooledH(int pooled_h) {} void ROIPooling::SetPooledW(int pooled_w) {} @@ -39,7 +39,7 @@ void ROIPooling::SetScale(float scale) {} #endif int ROIPooling::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { MS_LOG(ERROR) << "inputs number is not equal to " << kDoubleNum; return RET_ERROR; diff --git a/mindspore/lite/src/ops/roi_pooling.h b/mindspore/lite/src/ops/roi_pooling.h index 9daa506f0b1..f3085f43898 100644 --- a/mindspore/lite/src/ops/roi_pooling.h +++ b/mindspore/lite/src/ops/roi_pooling.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_R_O_I_POOLING_H_ +#define LITE_MINDSPORE_LITE_C_OPS_R_O_I_POOLING_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_R_O_I_POOLING_H_ -#define LITE_MINDSPORE_LITE_C_OPS_R_O_I_POOLING_H_ namespace mindspore { namespace lite { class ROIPooling : public PrimitiveC { public: - explicit ROIPooling(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ROIPooling(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ROIPooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetPooledH() const; diff --git a/mindspore/lite/src/ops/round.h b/mindspore/lite/src/ops/round.h index 48740d94129..b9c1fef1f1a 100644 --- a/mindspore/lite/src/ops/round.h +++ b/mindspore/lite/src/ops/round.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ROUND_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ROUND_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ROUND_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ROUND_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Round : public ArithmeticSelf { public: - explicit Round(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Round(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Round(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/rsqrt.h b/mindspore/lite/src/ops/rsqrt.h index 6f27497e936..396ca2e187e 100644 --- a/mindspore/lite/src/ops/rsqrt.h +++ b/mindspore/lite/src/ops/rsqrt.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_RSQRT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RSQRT_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_RSQRT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RSQRT_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Rsqrt : public ArithmeticSelf { public: - explicit Rsqrt(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Rsqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Rsqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scale.cc b/mindspore/lite/src/ops/scale.cc index 13cefbe69f3..23abc731adc 100644 --- a/mindspore/lite/src/ops/scale.cc +++ b/mindspore/lite/src/ops/scale.cc @@ -19,13 +19,13 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Scale::GetAxis() const { return this->primitive->value.AsScale()->axis; } +int Scale::GetAxis() const { return this->primitive_->value.AsScale()->axis; } -void Scale::SetAxis(int axis) { this->primitive->value.AsScale()->axis = axis; } +void Scale::SetAxis(int axis) { this->primitive_->value.AsScale()->axis = axis; } #else -int Scale::GetAxis() const { return this->primitive->value_as_Scale()->axis(); } +int Scale::GetAxis() const { return this->primitive_->value_as_Scale()->axis(); } void Scale::SetAxis(int axis) {} #endif diff --git a/mindspore/lite/src/ops/scale.h b/mindspore/lite/src/ops/scale.h index 2940cb4700a..10e68c80c1a 100644 --- a/mindspore/lite/src/ops/scale.h +++ b/mindspore/lite/src/ops/scale.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SCALE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SCALE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SCALE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SCALE_H_ namespace mindspore { namespace lite { class Scale : public PrimitiveC { public: - explicit Scale(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Scale(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Scale(schema::Primitive *primitive) : PrimitiveC(primitive) {} int GetAxis() const; void SetAxis(int axis); diff --git a/mindspore/lite/src/ops/scatter_nd.h b/mindspore/lite/src/ops/scatter_nd.h index cd5eef9b9f5..83cdf9fd0a9 100644 --- a/mindspore/lite/src/ops/scatter_nd.h +++ b/mindspore/lite/src/ops/scatter_nd.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SCATTER_ND_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SCATTER_ND_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SCATTER_ND_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SCATTER_ND_H_ namespace mindspore { namespace lite { class ScatterND : public PrimitiveC { public: - explicit ScatterND(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ScatterND(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ScatterND(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/shape.h b/mindspore/lite/src/ops/shape.h index f96e8cb7ea2..1178f4d9cca 100644 --- a/mindspore/lite/src/ops/shape.h +++ b/mindspore/lite/src/ops/shape.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SHAPE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SHAPE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SHAPE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SHAPE_H_ namespace mindspore { namespace lite { class Shape : public PrimitiveC { public: - explicit Shape(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Shape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Shape(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/sin.h b/mindspore/lite/src/ops/sin.h index d9753f78717..0e48f50178b 100644 --- a/mindspore/lite/src/ops/sin.h +++ b/mindspore/lite/src/ops/sin.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SIN_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SIN_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SIN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SIN_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Sin : public ArithmeticSelf { public: - explicit Sin(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Sin(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 80f52d59f55..13008d0c937 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -26,23 +26,23 @@ constexpr int kSliceInputNum = 1; constexpr int kSliceOutputNum = 1; } // namespace #ifdef PRIMITIVE_WRITEABLE -int SliceOp::GetFormat() const { return this->primitive->value.AsSlice()->format; } -std::vector SliceOp::GetBegin() const { return this->primitive->value.AsSlice()->begin; } -std::vector SliceOp::GetSize() const { return this->primitive->value.AsSlice()->size; } +int SliceOp::GetFormat() const { return this->primitive_->value.AsSlice()->format; } +std::vector SliceOp::GetBegin() const { return this->primitive_->value.AsSlice()->begin; } +std::vector SliceOp::GetSize() const { return this->primitive_->value.AsSlice()->size; } -void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = (schema::Format)format; } -void SliceOp::SetBegin(const std::vector &begin) { this->primitive->value.AsSlice()->begin = begin; } -void SliceOp::SetSize(const std::vector &size) { this->primitive->value.AsSlice()->size = size; } +void 
SliceOp::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; } +void SliceOp::SetBegin(const std::vector &begin) { this->primitive_->value.AsSlice()->begin = begin; } +void SliceOp::SetSize(const std::vector &size) { this->primitive_->value.AsSlice()->size = size; } #else -int SliceOp::GetFormat() const { return this->primitive->value_as_Slice()->format(); } +int SliceOp::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } std::vector SliceOp::GetBegin() const { - auto fb_vector = this->primitive->value_as_Slice()->begin(); + auto fb_vector = this->primitive_->value_as_Slice()->begin(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector SliceOp::GetSize() const { - auto fb_vector = this->primitive->value_as_Slice()->size(); + auto fb_vector = this->primitive_->value_as_Slice()->size(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -52,7 +52,7 @@ void SliceOp::SetSize(const std::vector &size) {} #endif int SliceOp::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/slice.h b/mindspore/lite/src/ops/slice.h index 2710a321baa..769de4dbd3a 100644 --- a/mindspore/lite/src/ops/slice.h +++ b/mindspore/lite/src/ops/slice.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ namespace mindspore { namespace lite { class SliceOp : public PrimitiveC { public: - explicit SliceOp(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SliceOp(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SliceOp(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index 338f0a53752..6d56cefe5c0 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int SoftMax::GetAxis() const { return this->primitive->value.AsSoftMax()->axis; } +int SoftMax::GetAxis() const { return this->primitive_->value.AsSoftMax()->axis; } -void SoftMax::SetAxis(int axis) { this->primitive->value.AsSoftMax()->axis = axis; } +void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = axis; } #else -int SoftMax::GetAxis() const { return this->primitive->value_as_SoftMax()->axis(); } +int SoftMax::GetAxis() const { return this->primitive_->value_as_SoftMax()->axis(); } void SoftMax::SetAxis(int axis) {} #endif int SoftMax::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/softmax.h 
b/mindspore/lite/src/ops/softmax.h index 96b439edf3e..a77e552eb00 100644 --- a/mindspore/lite/src/ops/softmax.h +++ b/mindspore/lite/src/ops/softmax.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFT_MAX_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SOFT_MAX_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFT_MAX_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SOFT_MAX_H_ namespace mindspore { namespace lite { class SoftMax : public PrimitiveC { public: - explicit SoftMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SoftMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc index 9e647e5c676..8e863ba30ff 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.cc +++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector SoftmaxCrossEntropy::GetAxis() const { return this->primitive->value.AsSoftmaxCrossEntropy()->axis; } +std::vector SoftmaxCrossEntropy::GetAxis() const { return this->primitive_->value.AsSoftmaxCrossEntropy()->axis; } void SoftmaxCrossEntropy::SetAxis(const std::vector &axis) { - this->primitive->value.AsSoftmaxCrossEntropy()->axis = axis; + this->primitive_->value.AsSoftmaxCrossEntropy()->axis = axis; } #else std::vector SoftmaxCrossEntropy::GetAxis() const { - auto fb_vector = this->primitive->value_as_SoftmaxCrossEntropy()->axis(); + auto fb_vector = 
this->primitive_->value_as_SoftmaxCrossEntropy()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.h b/mindspore/lite/src/ops/softmax_cross_entropy.h index d3816cdd88a..5bd160b8e17 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.h +++ b/mindspore/lite/src/ops/softmax_cross_entropy.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SOFTMAX_CROSS_ENTROPY_H_ namespace mindspore { namespace lite { class SoftmaxCrossEntropy : public PrimitiveC { public: - explicit SoftmaxCrossEntropy(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SoftmaxCrossEntropy(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc index db7bd7dd77d..2b92c0b8638 100644 --- a/mindspore/lite/src/ops/space_to_batch.cc +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -20,24 +20,24 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector SpaceToBatch::GetBlockShape() const { return this->primitive->value.AsSpaceToBatch()->blockShape; } -std::vector SpaceToBatch::GetPaddings() const { return this->primitive->value.AsSpaceToBatch()->paddings; } +std::vector SpaceToBatch::GetBlockShape() const { return this->primitive_->value.AsSpaceToBatch()->blockShape; } 
+std::vector SpaceToBatch::GetPaddings() const { return this->primitive_->value.AsSpaceToBatch()->paddings; } void SpaceToBatch::SetBlockShape(const std::vector &block_shape) { - this->primitive->value.AsSpaceToBatch()->blockShape = block_shape; + this->primitive_->value.AsSpaceToBatch()->blockShape = block_shape; } void SpaceToBatch::SetPaddings(const std::vector &paddings) { - this->primitive->value.AsSpaceToBatch()->paddings = paddings; + this->primitive_->value.AsSpaceToBatch()->paddings = paddings; } #else std::vector SpaceToBatch::GetBlockShape() const { - auto fb_vector = this->primitive->value_as_SpaceToBatch()->blockShape(); + auto fb_vector = this->primitive_->value_as_SpaceToBatch()->blockShape(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector SpaceToBatch::GetPaddings() const { - auto fb_vector = this->primitive->value_as_SpaceToBatch()->paddings(); + auto fb_vector = this->primitive_->value_as_SpaceToBatch()->paddings(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -52,7 +52,7 @@ constexpr int kPaddingsSize = 4; } // namespace int SpaceToBatch::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); return 1; diff --git a/mindspore/lite/src/ops/space_to_batch.h b/mindspore/lite/src/ops/space_to_batch.h index 4402dc3ddd5..e5afd6bb6c9 100644 --- a/mindspore/lite/src/ops/space_to_batch.h +++ b/mindspore/lite/src/ops/space_to_batch.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_H_ namespace mindspore { namespace lite { class SpaceToBatch : public PrimitiveC { public: - explicit SpaceToBatch(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SpaceToBatch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SpaceToBatch(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc index 2f6e14dddea..ff9a32ff5f8 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.cc +++ b/mindspore/lite/src/ops/space_to_batch_nd.cc @@ -19,24 +19,26 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector SpaceToBatchND::GetBlockShape() const { return this->primitive->value.AsSpaceToBatchND()->blockShape; } -std::vector SpaceToBatchND::GetPaddings() const { return this->primitive->value.AsSpaceToBatchND()->paddings; } +std::vector SpaceToBatchND::GetBlockShape() const { + return this->primitive_->value.AsSpaceToBatchND()->blockShape; +} +std::vector SpaceToBatchND::GetPaddings() const { return this->primitive_->value.AsSpaceToBatchND()->paddings; } void SpaceToBatchND::SetBlockShape(const std::vector &block_shape) { - this->primitive->value.AsSpaceToBatchND()->blockShape = block_shape; + this->primitive_->value.AsSpaceToBatchND()->blockShape = block_shape; } void SpaceToBatchND::SetPaddings(const std::vector &paddings) { - 
this->primitive->value.AsSpaceToBatchND()->paddings = paddings; + this->primitive_->value.AsSpaceToBatchND()->paddings = paddings; } #else std::vector SpaceToBatchND::GetBlockShape() const { - auto fb_vector = this->primitive->value_as_SpaceToBatchND()->blockShape(); + auto fb_vector = this->primitive_->value_as_SpaceToBatchND()->blockShape(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector SpaceToBatchND::GetPaddings() const { - auto fb_vector = this->primitive->value_as_SpaceToBatchND()->paddings(); + auto fb_vector = this->primitive_->value_as_SpaceToBatchND()->paddings(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/space_to_batch_nd.h b/mindspore/lite/src/ops/space_to_batch_nd.h index c02a20490a4..821ea1c2ea7 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.h +++ b/mindspore/lite/src/ops/space_to_batch_nd.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_N_D_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_N_D_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_N_D_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_BATCH_N_D_H_ namespace mindspore { namespace lite { class SpaceToBatchND : public PrimitiveC { public: - explicit SpaceToBatchND(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SpaceToBatchND(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SpaceToBatchND(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetBlockShape() const; std::vector GetPaddings() const; diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index 8df60481632..22c35a5da05 100644 --- 
a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -20,16 +20,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int SpaceToDepth::GetBlockSize() const { return this->primitive->value.AsSpaceToDepth()->blockSize; } -int SpaceToDepth::GetFormat() const { return this->primitive->value.AsSpaceToDepth()->format; } +int SpaceToDepth::GetBlockSize() const { return this->primitive_->value.AsSpaceToDepth()->blockSize; } +int SpaceToDepth::GetFormat() const { return this->primitive_->value.AsSpaceToDepth()->format; } -void SpaceToDepth::SetBlockSize(int block_size) { this->primitive->value.AsSpaceToDepth()->blockSize = block_size; } -void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = (schema::Format)format; } +void SpaceToDepth::SetBlockSize(int block_size) { this->primitive_->value.AsSpaceToDepth()->blockSize = block_size; } +void SpaceToDepth::SetFormat(int format) { this->primitive_->value.AsSpaceToDepth()->format = (schema::Format)format; } #else -int SpaceToDepth::GetBlockSize() const { return this->primitive->value_as_SpaceToDepth()->blockSize(); } -int SpaceToDepth::GetFormat() const { return this->primitive->value_as_SpaceToDepth()->format(); } +int SpaceToDepth::GetBlockSize() const { return this->primitive_->value_as_SpaceToDepth()->blockSize(); } +int SpaceToDepth::GetFormat() const { return this->primitive_->value_as_SpaceToDepth()->format(); } void SpaceToDepth::SetBlockSize(int block_size) {} void SpaceToDepth::SetFormat(int format) {} @@ -40,7 +40,7 @@ constexpr int kSpaceToDepthInputNum = 1; } // namespace int SpaceToDepth::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kSpaceToDepthOutputNum || inputs.size() != kSpaceToDepthInputNum) { MS_LOG(ERROR) << "Invalid output/input size! 
output size: " << outputs.size() << ",input size: " << inputs.size(); return 1; diff --git a/mindspore/lite/src/ops/space_to_depth.h b/mindspore/lite/src/ops/space_to_depth.h index eddc3e4e2c1..9f374f3b8de 100644 --- a/mindspore/lite/src/ops/space_to_depth.h +++ b/mindspore/lite/src/ops/space_to_depth.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_DEPTH_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_DEPTH_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_DEPTH_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SPACE_TO_DEPTH_H_ namespace mindspore { namespace lite { class SpaceToDepth : public PrimitiveC { public: - explicit SpaceToDepth(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SpaceToDepth(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SpaceToDepth(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; diff --git a/mindspore/lite/src/ops/sparse_to_dense.cc b/mindspore/lite/src/ops/sparse_to_dense.cc index 8b5aab94f1b..85a978a2f20 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.cc +++ b/mindspore/lite/src/ops/sparse_to_dense.cc @@ -19,41 +19,45 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector SparseToDense::GetOutputShape() const { return this->primitive->value.AsSparseToDense()->outputShape; } -std::vector SparseToDense::GetSparseValue() const { return this->primitive->value.AsSparseToDense()->sparseValue; } -std::vector SparseToDense::GetDefaultValue() const { - return this->primitive->value.AsSparseToDense()->defaultValue; +std::vector SparseToDense::GetOutputShape() const { + return 
this->primitive_->value.AsSparseToDense()->outputShape; } -bool SparseToDense::GetValidateIndices() const { return this->primitive->value.AsSparseToDense()->validateIndices; } +std::vector SparseToDense::GetSparseValue() const { + return this->primitive_->value.AsSparseToDense()->sparseValue; +} +std::vector SparseToDense::GetDefaultValue() const { + return this->primitive_->value.AsSparseToDense()->defaultValue; +} +bool SparseToDense::GetValidateIndices() const { return this->primitive_->value.AsSparseToDense()->validateIndices; } void SparseToDense::SetOutputShape(const std::vector &output_shape) { - this->primitive->value.AsSparseToDense()->outputShape = output_shape; + this->primitive_->value.AsSparseToDense()->outputShape = output_shape; } void SparseToDense::SetSparseValue(const std::vector &sparse_value) { - this->primitive->value.AsSparseToDense()->sparseValue = sparse_value; + this->primitive_->value.AsSparseToDense()->sparseValue = sparse_value; } void SparseToDense::SetDefaultValue(const std::vector &default_value) { - this->primitive->value.AsSparseToDense()->defaultValue = default_value; + this->primitive_->value.AsSparseToDense()->defaultValue = default_value; } void SparseToDense::SetValidateIndices(bool validate_indices) { - this->primitive->value.AsSparseToDense()->validateIndices = validate_indices; + this->primitive_->value.AsSparseToDense()->validateIndices = validate_indices; } #else std::vector SparseToDense::GetOutputShape() const { - auto fb_vector = this->primitive->value_as_SparseToDense()->outputShape(); + auto fb_vector = this->primitive_->value_as_SparseToDense()->outputShape(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector SparseToDense::GetSparseValue() const { - auto fb_vector = this->primitive->value_as_SparseToDense()->sparseValue(); + auto fb_vector = this->primitive_->value_as_SparseToDense()->sparseValue(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector 
SparseToDense::GetDefaultValue() const { - auto fb_vector = this->primitive->value_as_SparseToDense()->defaultValue(); + auto fb_vector = this->primitive_->value_as_SparseToDense()->defaultValue(); return std::vector(fb_vector->begin(), fb_vector->end()); } -bool SparseToDense::GetValidateIndices() const { return this->primitive->value_as_SparseToDense()->validateIndices(); } +bool SparseToDense::GetValidateIndices() const { return this->primitive_->value_as_SparseToDense()->validateIndices(); } void SparseToDense::SetOutputShape(const std::vector &output_shape) {} void SparseToDense::SetSparseValue(const std::vector &sparse_value) {} diff --git a/mindspore/lite/src/ops/sparse_to_dense.h b/mindspore/lite/src/ops/sparse_to_dense.h index 07b2094365c..c35663853de 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.h +++ b/mindspore/lite/src/ops/sparse_to_dense.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SPARSE_TO_DENSE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SPARSE_TO_DENSE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SPARSE_TO_DENSE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SPARSE_TO_DENSE_H_ namespace mindspore { namespace lite { class SparseToDense : public PrimitiveC { public: - explicit SparseToDense(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SparseToDense(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit SparseToDense(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::vector GetOutputShape() const; std::vector GetSparseValue() const; diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index a9a6dd082ca..6df48138a04 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ 
-19,24 +19,24 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Split::GetNumberSplit() const { return this->primitive->value.AsSplit()->numberSplit; } -std::vector Split::GetSizeSplits() const { return this->primitive->value.AsSplit()->sizeSplits; } -int Split::GetSplitDim() const { return this->primitive->value.AsSplit()->splitDim; } +int Split::GetNumberSplit() const { return this->primitive_->value.AsSplit()->numberSplit; } +std::vector Split::GetSizeSplits() const { return this->primitive_->value.AsSplit()->sizeSplits; } +int Split::GetSplitDim() const { return this->primitive_->value.AsSplit()->splitDim; } -void Split::SetNumberSplit(int number_split) { this->primitive->value.AsSplit()->numberSplit = number_split; } +void Split::SetNumberSplit(int number_split) { this->primitive_->value.AsSplit()->numberSplit = number_split; } void Split::SetSizeSplits(const std::vector &size_splits) { - this->primitive->value.AsSplit()->sizeSplits = size_splits; + this->primitive_->value.AsSplit()->sizeSplits = size_splits; } -void Split::SetSplitDim(int split_dim) { this->primitive->value.AsSplit()->splitDim = split_dim; } +void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; } #else -int Split::GetNumberSplit() const { return this->primitive->value_as_Split()->numberSplit(); } +int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); } std::vector Split::GetSizeSplits() const { - auto fb_vector = this->primitive->value_as_Split()->sizeSplits(); + auto fb_vector = this->primitive_->value_as_Split()->sizeSplits(); return std::vector(fb_vector->begin(), fb_vector->end()); } -int Split::GetSplitDim() const { return this->primitive->value_as_Split()->splitDim(); } +int Split::GetSplitDim() const { return this->primitive_->value_as_Split()->splitDim(); } void Split::SetNumberSplit(int number_split) {} void Split::SetSizeSplits(const std::vector &size_splits) {} @@ -47,7 
+47,7 @@ namespace { constexpr int kSplitInputNum = 1; } // namespace int Split::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); MS_ASSERT(spilt_prim != nullptr); diff --git a/mindspore/lite/src/ops/split.h b/mindspore/lite/src/ops/split.h index c065732574a..433d3259f79 100644 --- a/mindspore/lite/src/ops/split.h +++ b/mindspore/lite/src/ops/split.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SPLIT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SPLIT_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SPLIT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SPLIT_H_ namespace mindspore { namespace lite { class Split : public PrimitiveC { public: - explicit Split(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Split(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Split(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNumberSplit() const; diff --git a/mindspore/lite/src/ops/sqrt.h b/mindspore/lite/src/ops/sqrt.h index 2d861251cb1..88ceafa07eb 100644 --- a/mindspore/lite/src/ops/sqrt.h +++ b/mindspore/lite/src/ops/sqrt.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SQRT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SQRT_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SQRT_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SQRT_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Sqrt : public ArithmeticSelf { public: - explicit Sqrt(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Sqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/square.h b/mindspore/lite/src/ops/square.h index eb847039c0b..dca72cfa5ee 100644 --- a/mindspore/lite/src/ops/square.h +++ b/mindspore/lite/src/ops/square.h @@ -13,26 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUARE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SQUARE_H_ #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic_self.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUARE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SQUARE_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class Square : public ArithmeticSelf { public: - explicit Square(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} +#endif + explicit Square(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squared_difference.h b/mindspore/lite/src/ops/squared_difference.h index 3f4768fc111..e1d88db06cb 100644 --- a/mindspore/lite/src/ops/squared_difference.h +++ b/mindspore/lite/src/ops/squared_difference.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUARED_DIFFERENCE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SQUARED_DIFFERENCE_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUARED_DIFFERENCE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SQUARED_DIFFERENCE_H_ +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { class SquaredDifference : public Arithmetic { public: - explicit SquaredDifference(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit SquaredDifference(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 2496c77fd95..684300e38a4 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -19,14 +19,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Squeeze::GetAxis() const { return this->primitive->value.AsSqueeze()->axis; } +std::vector Squeeze::GetAxis() const { return this->primitive_->value.AsSqueeze()->axis; } -void Squeeze::SetAxis(const std::vector &axis) { this->primitive->value.AsSqueeze()->axis = axis; } +void Squeeze::SetAxis(const std::vector &axis) { this->primitive_->value.AsSqueeze()->axis = axis; } #else std::vector Squeeze::GetAxis() const { - auto fb_vector = this->primitive->value_as_Squeeze()->axis(); + auto fb_vector = this->primitive_->value_as_Squeeze()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -38,7 +38,7 @@ constexpr int kSqueezeInputNum = 1; constexpr int kSqueezeOutputNum = 1; } // namespace int Squeeze::InferShape(std::vector inputs_, std::vector outputs_) { - 
MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (kSqueezeInputNum != inputs_.size()) { MS_LOG(ERROR) << "Add should has " << kSqueezeInputNum << " inputs"; return -1; diff --git a/mindspore/lite/src/ops/squeeze.h b/mindspore/lite/src/ops/squeeze.h index c9b576f8766..7ca9adbcc54 100644 --- a/mindspore/lite/src/ops/squeeze.h +++ b/mindspore/lite/src/ops/squeeze.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ namespace mindspore { namespace lite { class Squeeze : public PrimitiveC { public: - explicit Squeeze(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Squeeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index c87bd86dc7e..17f7c05ac31 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -19,20 +19,20 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Stack::GetAxis() const { return this->primitive->value.AsStack()->axis; } -int Stack::GetN() const { return this->primitive->value.AsStack()->n; } -std::vector Stack::GetIsScale() const { return this->primitive->value.AsStack()->isScale; } +int Stack::GetAxis() const { return this->primitive_->value.AsStack()->axis; } +int Stack::GetN() const { return this->primitive_->value.AsStack()->n; } +std::vector 
Stack::GetIsScale() const { return this->primitive_->value.AsStack()->isScale; } -void Stack::SetAxis(int axis) { this->primitive->value.AsStack()->axis = axis; } -void Stack::SetN(int n) { this->primitive->value.AsStack()->n = n; } -void Stack::SetIsScale(const std::vector &is_scale) { this->primitive->value.AsStack()->isScale = is_scale; } +void Stack::SetAxis(int axis) { this->primitive_->value.AsStack()->axis = axis; } +void Stack::SetN(int n) { this->primitive_->value.AsStack()->n = n; } +void Stack::SetIsScale(const std::vector &is_scale) { this->primitive_->value.AsStack()->isScale = is_scale; } #else -int Stack::GetAxis() const { return this->primitive->value_as_Stack()->axis(); } -int Stack::GetN() const { return this->primitive->value_as_Stack()->n(); } +int Stack::GetAxis() const { return this->primitive_->value_as_Stack()->axis(); } +int Stack::GetN() const { return this->primitive_->value_as_Stack()->n(); } std::vector Stack::GetIsScale() const { - auto fb_vector = this->primitive->value_as_Stack()->isScale(); + auto fb_vector = this->primitive_->value_as_Stack()->isScale(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -46,7 +46,7 @@ constexpr int kStackOutputNum = 1; constexpr int kStackMinInputNum = 2; } // namespace int Stack::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kStackOutputNum) { MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/stack.h b/mindspore/lite/src/ops/stack.h index 14c78a079ef..7f3a2725f40 100644 --- a/mindspore/lite/src/ops/stack.h +++ b/mindspore/lite/src/ops/stack.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_STACK_H_ +#define LITE_MINDSPORE_LITE_C_OPS_STACK_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_STACK_H_ -#define LITE_MINDSPORE_LITE_C_OPS_STACK_H_ namespace mindspore { namespace lite { class Stack : public PrimitiveC { public: - explicit Stack(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Stack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Stack(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 80714673a4b..1f2b2e4e4ec 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -19,57 +19,57 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int StridedSlice::GetBeginMask() const { return this->primitive->value.AsStridedSlice()->beginMask; } -int StridedSlice::GetEndMask() const { return this->primitive->value.AsStridedSlice()->endMask; } -int StridedSlice::GetEllipsisMask() const { return this->primitive->value.AsStridedSlice()->ellipsisMask; } -int StridedSlice::GetNewAxisMask() const { return this->primitive->value.AsStridedSlice()->newAxisMask; } -int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value.AsStridedSlice()->shrinkAxisMask; } -std::vector StridedSlice::GetBegin() const { return this->primitive->value.AsStridedSlice()->begin; } -std::vector StridedSlice::GetEnd() const { return this->primitive->value.AsStridedSlice()->end; } -std::vector StridedSlice::GetStride() const { return this->primitive->value.AsStridedSlice()->stride; } -std::vector 
StridedSlice::GetIsScale() const { return this->primitive->value.AsStridedSlice()->isScale; } +int StridedSlice::GetBeginMask() const { return this->primitive_->value.AsStridedSlice()->beginMask; } +int StridedSlice::GetEndMask() const { return this->primitive_->value.AsStridedSlice()->endMask; } +int StridedSlice::GetEllipsisMask() const { return this->primitive_->value.AsStridedSlice()->ellipsisMask; } +int StridedSlice::GetNewAxisMask() const { return this->primitive_->value.AsStridedSlice()->newAxisMask; } +int StridedSlice::GetShrinkAxisMask() const { return this->primitive_->value.AsStridedSlice()->shrinkAxisMask; } +std::vector StridedSlice::GetBegin() const { return this->primitive_->value.AsStridedSlice()->begin; } +std::vector StridedSlice::GetEnd() const { return this->primitive_->value.AsStridedSlice()->end; } +std::vector StridedSlice::GetStride() const { return this->primitive_->value.AsStridedSlice()->stride; } +std::vector StridedSlice::GetIsScale() const { return this->primitive_->value.AsStridedSlice()->isScale; } -void StridedSlice::SetBeginMask(int begin_mask) { this->primitive->value.AsStridedSlice()->beginMask = begin_mask; } -void StridedSlice::SetEndMask(int end_mask) { this->primitive->value.AsStridedSlice()->endMask = end_mask; } +void StridedSlice::SetBeginMask(int begin_mask) { this->primitive_->value.AsStridedSlice()->beginMask = begin_mask; } +void StridedSlice::SetEndMask(int end_mask) { this->primitive_->value.AsStridedSlice()->endMask = end_mask; } void StridedSlice::SetEllipsisMask(int ellipsis_mask) { - this->primitive->value.AsStridedSlice()->ellipsisMask = ellipsis_mask; + this->primitive_->value.AsStridedSlice()->ellipsisMask = ellipsis_mask; } void StridedSlice::SetNewAxisMask(int new_axis_mask) { - this->primitive->value.AsStridedSlice()->newAxisMask = new_axis_mask; + this->primitive_->value.AsStridedSlice()->newAxisMask = new_axis_mask; } void StridedSlice::SetShrinkAxisMask(int shrink_axis_mask) { - 
this->primitive->value.AsStridedSlice()->shrinkAxisMask = shrink_axis_mask; + this->primitive_->value.AsStridedSlice()->shrinkAxisMask = shrink_axis_mask; } -void StridedSlice::SetBegin(const std::vector &begin) { this->primitive->value.AsStridedSlice()->begin = begin; } -void StridedSlice::SetEnd(const std::vector &end) { this->primitive->value.AsStridedSlice()->end = end; } +void StridedSlice::SetBegin(const std::vector &begin) { this->primitive_->value.AsStridedSlice()->begin = begin; } +void StridedSlice::SetEnd(const std::vector &end) { this->primitive_->value.AsStridedSlice()->end = end; } void StridedSlice::SetStride(const std::vector &stride) { - this->primitive->value.AsStridedSlice()->stride = stride; + this->primitive_->value.AsStridedSlice()->stride = stride; } void StridedSlice::SetIsScale(const std::vector &is_scale) { - this->primitive->value.AsStridedSlice()->isScale = is_scale; + this->primitive_->value.AsStridedSlice()->isScale = is_scale; } #else -int StridedSlice::GetBeginMask() const { return this->primitive->value_as_StridedSlice()->beginMask(); } -int StridedSlice::GetEndMask() const { return this->primitive->value_as_StridedSlice()->endMask(); } -int StridedSlice::GetEllipsisMask() const { return this->primitive->value_as_StridedSlice()->ellipsisMask(); } -int StridedSlice::GetNewAxisMask() const { return this->primitive->value_as_StridedSlice()->newAxisMask(); } -int StridedSlice::GetShrinkAxisMask() const { return this->primitive->value_as_StridedSlice()->shrinkAxisMask(); } +int StridedSlice::GetBeginMask() const { return this->primitive_->value_as_StridedSlice()->beginMask(); } +int StridedSlice::GetEndMask() const { return this->primitive_->value_as_StridedSlice()->endMask(); } +int StridedSlice::GetEllipsisMask() const { return this->primitive_->value_as_StridedSlice()->ellipsisMask(); } +int StridedSlice::GetNewAxisMask() const { return this->primitive_->value_as_StridedSlice()->newAxisMask(); } +int StridedSlice::GetShrinkAxisMask() 
const { return this->primitive_->value_as_StridedSlice()->shrinkAxisMask(); } std::vector StridedSlice::GetBegin() const { - auto fb_vector = this->primitive->value_as_StridedSlice()->begin(); + auto fb_vector = this->primitive_->value_as_StridedSlice()->begin(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector StridedSlice::GetEnd() const { - auto fb_vector = this->primitive->value_as_StridedSlice()->end(); + auto fb_vector = this->primitive_->value_as_StridedSlice()->end(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector StridedSlice::GetStride() const { - auto fb_vector = this->primitive->value_as_StridedSlice()->stride(); + auto fb_vector = this->primitive_->value_as_StridedSlice()->stride(); return std::vector(fb_vector->begin(), fb_vector->end()); } std::vector StridedSlice::GetIsScale() const { - auto fb_vector = this->primitive->value_as_StridedSlice()->isScale(); + auto fb_vector = this->primitive_->value_as_StridedSlice()->isScale(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -154,7 +154,7 @@ void StridedSlice::ApplyEndMask() { } int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kStridedSliceOutputNum) { MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/strided_slice.h b/mindspore/lite/src/ops/strided_slice.h index bc09eafc1f3..66df8b29e35 100644 --- a/mindspore/lite/src/ops/strided_slice.h +++ b/mindspore/lite/src/ops/strided_slice.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_STRIDED_SLICE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_STRIDED_SLICE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_STRIDED_SLICE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_STRIDED_SLICE_H_ namespace mindspore { namespace lite { class StridedSlice : public PrimitiveC { public: - explicit StridedSlice(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit StridedSlice(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit StridedSlice(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBeginMask() const; diff --git a/mindspore/lite/src/ops/sub.cc b/mindspore/lite/src/ops/sub.cc index 6ad68eb2b44..45b188a2cf1 100644 --- a/mindspore/lite/src/ops/sub.cc +++ b/mindspore/lite/src/ops/sub.cc @@ -19,15 +19,15 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Sub::GetActivationType() const { return this->primitive->value.AsSub()->activationType; } +int Sub::GetActivationType() const { return this->primitive_->value.AsSub()->activationType; } void Sub::SetActivationType(int activation_type) { - this->primitive->value.AsSub()->activationType = (schema::ActivationType)activation_type; + this->primitive_->value.AsSub()->activationType = (schema::ActivationType)activation_type; } #else -int Sub::GetActivationType() const { return this->primitive->value_as_Sub()->activationType(); } +int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } void Sub::SetActivationType(int activation_type) {} #endif diff --git a/mindspore/lite/src/ops/sub.h b/mindspore/lite/src/ops/sub.h index 2724a83356a..2738d183978 100644 --- a/mindspore/lite/src/ops/sub.h +++ 
b/mindspore/lite/src/ops/sub.h @@ -14,27 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_SUB_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SUB_H_ + #include #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" #include "src/ops/arithmetic.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_SUB_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SUB_H_ - namespace mindspore { namespace lite { class Sub : public Arithmetic { public: - explicit Sub(OriginPrimitive *primitive) : Arithmetic(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} +#endif + explicit Sub(schema::Primitive *primitive) : Arithmetic(primitive) {} int GetActivationType() const; void SetActivationType(int activation_type); diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 87bd725554e..7714a30106b 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -20,25 +20,25 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Tile::GetMultiples() const { return this->primitive->value.AsTile()->multiples; } +std::vector Tile::GetMultiples() const { return this->primitive_->value.AsTile()->multiples; } -void Tile::SetMultiples(const std::vector &multiples) { this->primitive->value.AsTile()->multiples = multiples; } +void Tile::SetMultiples(const std::vector &multiples) { this->primitive_->value.AsTile()->multiples = multiples; } -std::vector Tile::GetDims() const { return this->primitive->value.AsTile()->multiples; } +std::vector Tile::GetDims() const { return this->primitive_->value.AsTile()->dims; } -void Tile::SetDims(const std::vector &dims) { this->primitive->value.AsTile()->dims = dims; } +void Tile::SetDims(const std::vector &dims) { this->primitive_->value.AsTile()->dims = dims; } #else std::vector
Tile::GetMultiples() const { - auto fb_vector = this->primitive->value_as_Tile()->multiples(); + auto fb_vector = this->primitive_->value_as_Tile()->multiples(); return std::vector(fb_vector->begin(), fb_vector->end()); } void Tile::SetMultiples(const std::vector &multiples) {} std::vector Tile::GetDims() const { - auto fb_vector = this->primitive->value_as_Tile()->dims(); + auto fb_vector = this->primitive_->value_as_Tile()->dims(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -46,7 +46,7 @@ void Tile::SetDims(const std::vector &dims) {} #endif int Tile::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/tile.h b/mindspore/lite/src/ops/tile.h index 720b8baabd6..dfc025e5af0 100644 --- a/mindspore/lite/src/ops/tile.h +++ b/mindspore/lite/src/ops/tile.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_TILE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_TILE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_TILE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_TILE_H_ namespace mindspore { namespace lite { class Tile : public PrimitiveC { public: - explicit Tile(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Tile(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Tile(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMultiples() const; diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 0f3abac581f..38cb1e89ff8 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -19,23 +19,23 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int TopK::GetK() const { return this->primitive->value.AsTopK()->k; } -bool TopK::GetSorted() const { return this->primitive->value.AsTopK()->sorted; } +int TopK::GetK() const { return this->primitive_->value.AsTopK()->k; } +bool TopK::GetSorted() const { return this->primitive_->value.AsTopK()->sorted; } -void TopK::SetK(int k) { this->primitive->value.AsTopK()->k = k; } -void TopK::SetSorted(bool sorted) { this->primitive->value.AsTopK()->sorted = sorted; } +void TopK::SetK(int k) { this->primitive_->value.AsTopK()->k = k; } +void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = sorted; } #else -int TopK::GetK() const { return this->primitive->value_as_TopK()->k(); } -bool TopK::GetSorted() const { return this->primitive->value_as_TopK()->sorted(); } +int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); } +bool TopK::GetSorted() const { 
return this->primitive_->value_as_TopK()->sorted(); } void TopK::SetK(int k) {} void TopK::SetSorted(bool sorted) {} #endif int TopK::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/topk.h b/mindspore/lite/src/ops/topk.h index 3b7ec46cced..7b58ecaab2f 100644 --- a/mindspore/lite/src/ops/topk.h +++ b/mindspore/lite/src/ops/topk.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_TOP_K_H_ +#define LITE_MINDSPORE_LITE_C_OPS_TOP_K_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_TOP_K_H_ -#define LITE_MINDSPORE_LITE_C_OPS_TOP_K_H_ namespace mindspore { namespace lite { class TopK : public PrimitiveC { public: - explicit TopK(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit TopK(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetK() const; diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index d3666e57ea1..ef6750015f2 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -21,26 +21,26 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Transpose::GetPerm() const { return this->primitive->value.AsTranspose()->perm; } -bool Transpose::GetConjugate() const { return 
this->primitive->value.AsTranspose()->conjugate; } +std::vector Transpose::GetPerm() const { return this->primitive_->value.AsTranspose()->perm; } +bool Transpose::GetConjugate() const { return this->primitive_->value.AsTranspose()->conjugate; } -void Transpose::SetPerm(const std::vector &perm) { this->primitive->value.AsTranspose()->perm = perm; } -void Transpose::SetConjugate(bool conjugate) { this->primitive->value.AsTranspose()->conjugate = conjugate; } +void Transpose::SetPerm(const std::vector &perm) { this->primitive_->value.AsTranspose()->perm = perm; } +void Transpose::SetConjugate(bool conjugate) { this->primitive_->value.AsTranspose()->conjugate = conjugate; } #else std::vector Transpose::GetPerm() const { - auto fb_vector = this->primitive->value_as_Transpose()->perm(); + auto fb_vector = this->primitive_->value_as_Transpose()->perm(); return std::vector(fb_vector->begin(), fb_vector->end()); } -bool Transpose::GetConjugate() const { return this->primitive->value_as_Transpose()->conjugate(); } +bool Transpose::GetConjugate() const { return this->primitive_->value_as_Transpose()->conjugate(); } void Transpose::SetPerm(const std::vector &perm) {} void Transpose::SetConjugate(bool conjugate) {} #endif int Transpose::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/transpose.h b/mindspore/lite/src/ops/transpose.h index 6092c9d2da7..be11939bec2 100644 --- a/mindspore/lite/src/ops/transpose.h +++ b/mindspore/lite/src/ops/transpose.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_TRANSPOSE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_TRANSPOSE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_TRANSPOSE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_TRANSPOSE_H_ namespace mindspore { namespace lite { class Transpose : public PrimitiveC { public: - explicit Transpose(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Transpose(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Transpose(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPerm() const; diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc index 56b8a3b2836..652114cafab 100644 --- a/mindspore/lite/src/ops/unique.cc +++ b/mindspore/lite/src/ops/unique.cc @@ -19,19 +19,19 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Unique::GetOutType() const { return this->primitive->value.AsUnique()->outType; } +int Unique::GetOutType() const { return this->primitive_->value.AsUnique()->outType; } -void Unique::SetOutType(int out_type) { this->primitive->value.AsUnique()->outType = out_type; } +void Unique::SetOutType(int out_type) { this->primitive_->value.AsUnique()->outType = out_type; } #else -int Unique::GetOutType() const { return this->primitive->value_as_Unique()->outType(); } +int Unique::GetOutType() const { return this->primitive_->value_as_Unique()->outType(); } void Unique::SetOutType(int out_type) {} #endif int Unique::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { MS_LOG(ERROR) << "input 
size: " << inputs_.size() << ", output size: " << outputs_.size(); return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/ops/unique.h b/mindspore/lite/src/ops/unique.h index 3091d967187..c623ab89e97 100644 --- a/mindspore/lite/src/ops/unique.h +++ b/mindspore/lite/src/ops/unique.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIQUE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_UNIQUE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIQUE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_UNIQUE_H_ namespace mindspore { namespace lite { class Unique : public PrimitiveC { public: - explicit Unique(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Unique(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Unique(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetOutType() const; diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc index 1d7f682f4c4..f515bc9567e 100644 --- a/mindspore/lite/src/ops/unsqueeze.cc +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -22,14 +22,14 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Unsqueeze::GetAxis() const { return this->primitive->value.AsUnsqueeze()->axis; } +std::vector Unsqueeze::GetAxis() const { return this->primitive_->value.AsUnsqueeze()->axis; } -void Unsqueeze::SetAxis(const std::vector &axis) { this->primitive->value.AsUnsqueeze()->axis = axis; } +void Unsqueeze::SetAxis(const std::vector &axis) { this->primitive_->value.AsUnsqueeze()->axis = axis; } #else bool predicate(int n) { return n != 1; } std::vector Unsqueeze::GetAxis() const { - auto fb_vector = 
this->primitive->value_as_Unsqueeze()->axis(); + auto fb_vector = this->primitive_->value_as_Unsqueeze()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -37,7 +37,7 @@ void Unsqueeze::SetAxis(const std::vector &axis) {} #endif int Unsqueeze::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/unsqueeze.h b/mindspore/lite/src/ops/unsqueeze.h index 98c02165f98..1873feaa678 100644 --- a/mindspore/lite/src/ops/unsqueeze.h +++ b/mindspore/lite/src/ops/unsqueeze.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_UNSQUEEZE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_UNSQUEEZE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_UNSQUEEZE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_UNSQUEEZE_H_ namespace mindspore { namespace lite { class Unsqueeze : public PrimitiveC { public: - explicit Unsqueeze(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Unsqueeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Unsqueeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index 7119f180190..6d3a6ff03d2 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Unstack::GetNum() const { return this->primitive->value.AsUnstack()->num; } -int 
Unstack::GetAxis() const { return this->primitive->value.AsUnstack()->axis; } +int Unstack::GetNum() const { return this->primitive_->value.AsUnstack()->num; } +int Unstack::GetAxis() const { return this->primitive_->value.AsUnstack()->axis; } -void Unstack::SetNum(int num) { this->primitive->value.AsUnstack()->num = num; } -void Unstack::SetAxis(int axis) { this->primitive->value.AsUnstack()->axis = axis; } +void Unstack::SetNum(int num) { this->primitive_->value.AsUnstack()->num = num; } +void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = axis; } #else -int Unstack::GetNum() const { return this->primitive->value_as_Unstack()->num(); } -int Unstack::GetAxis() const { return this->primitive->value_as_Unstack()->axis(); } +int Unstack::GetNum() const { return this->primitive_->value_as_Unstack()->num(); } +int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); } void Unstack::SetNum(int num) {} void Unstack::SetAxis(int axis) {} diff --git a/mindspore/lite/src/ops/unstack.h b/mindspore/lite/src/ops/unstack.h index d5ab96b4f4f..8c7b357e5dd 100644 --- a/mindspore/lite/src/ops/unstack.h +++ b/mindspore/lite/src/ops/unstack.h @@ -14,25 +14,23 @@ * limitations under the License. 
*/ +#ifndef LITE_MINDSPORE_LITE_C_OPS_UNSTACK_H_ +#define LITE_MINDSPORE_LITE_C_OPS_UNSTACK_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_UNSTACK_H_ -#define LITE_MINDSPORE_LITE_C_OPS_UNSTACK_H_ namespace mindspore { namespace lite { class Unstack : public PrimitiveC { public: - explicit Unstack(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Unstack(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNum() const; diff --git a/mindspore/lite/src/ops/upsample.cc b/mindspore/lite/src/ops/upsample.cc index 96941e22c69..9f0623bc294 100644 --- a/mindspore/lite/src/ops/upsample.cc +++ b/mindspore/lite/src/ops/upsample.cc @@ -15,21 +15,22 @@ */ #include "src/ops/upsample.h" +#include namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::string Upsample::GetMode() const { return this->primitive->value.AsUpsample()->mode; } -std::vector Upsample::GetScales() const { return this->primitive->value.AsUpsample()->scales; } +std::string Upsample::GetMode() const { return this->primitive_->value.AsUpsample()->mode; } +std::vector Upsample::GetScales() const { return this->primitive_->value.AsUpsample()->scales; } -void Upsample::SetMode(std::string mode) { this->primitive->value.AsUpsample()->mode = mode; } -void Upsample::SetScales(const std::vector &scales) { this->primitive->value.AsUpsample()->scales = scales; } +void Upsample::SetMode(std::string mode) { this->primitive_->value.AsUpsample()->mode = mode; } +void Upsample::SetScales(const std::vector &scales) { this->primitive_->value.AsUpsample()->scales = scales; } #else -std::string 
Upsample::GetMode() const { return this->primitive->value_as_Upsample()->mode()->str(); } +std::string Upsample::GetMode() const { return this->primitive_->value_as_Upsample()->mode()->str(); } std::vector Upsample::GetScales() const { - auto fb_vector = this->primitive->value_as_Upsample()->scales(); + auto fb_vector = this->primitive_->value_as_Upsample()->scales(); return std::vector(fb_vector->begin(), fb_vector->end()); } diff --git a/mindspore/lite/src/ops/upsample.h b/mindspore/lite/src/ops/upsample.h index ced71016c56..61f0869ae74 100644 --- a/mindspore/lite/src/ops/upsample.h +++ b/mindspore/lite/src/ops/upsample.h @@ -14,26 +14,24 @@ * limitations under the License. */ -#include +#ifndef LITE_MINDSPORE_LITE_C_OPS_UPSAMPLE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_UPSAMPLE_H_ + #include #include #include +#include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_UPSAMPLE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_UPSAMPLE_H_ namespace mindspore { namespace lite { class Upsample : public PrimitiveC { public: - explicit Upsample(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Upsample(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Upsample(schema::Primitive *primitive) : PrimitiveC(primitive) {} std::string GetMode() const; std::vector GetScales() const; diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc index 35bc3056ddb..1641fd5e2c0 100644 --- a/mindspore/lite/src/ops/where.cc +++ b/mindspore/lite/src/ops/where.cc @@ -19,16 +19,16 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -std::vector Where::GetCondition() const { return this->primitive->value.AsWhere()->condition; } +std::vector Where::GetCondition() const { return this->primitive_->value.AsWhere()->condition; } 
void Where::SetCondition(const std::vector &condition) { - this->primitive->value.AsWhere()->condition = condition; + this->primitive_->value.AsWhere()->condition = condition; } #else std::vector Where::GetCondition() const { - auto fb_vector = this->primitive->value_as_Where()->condition(); + auto fb_vector = this->primitive_->value_as_Where()->condition(); return std::vector(fb_vector->begin(), fb_vector->end()); } @@ -36,7 +36,7 @@ void Where::SetCondition(const std::vector &condition) {} #endif int Where::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/where.h b/mindspore/lite/src/ops/where.h index ef9a9f179e3..9279a147b86 100644 --- a/mindspore/lite/src/ops/where.h +++ b/mindspore/lite/src/ops/where.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ namespace mindspore { namespace lite { class Where : public PrimitiveC { public: - explicit Where(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit Where(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit Where(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetCondition() const; diff --git a/mindspore/lite/src/ops/zeros_like.cc b/mindspore/lite/src/ops/zeros_like.cc index eaa87c1cf37..23e674617d6 100644 --- a/mindspore/lite/src/ops/zeros_like.cc +++ 
b/mindspore/lite/src/ops/zeros_like.cc @@ -19,7 +19,7 @@ namespace mindspore { namespace lite { int ZerosLike::InferShape(std::vector inputs_, std::vector outputs_) { - MS_ASSERT(this->primitive != nullptr); + MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); auto output = outputs_.front(); diff --git a/mindspore/lite/src/ops/zeros_like.h b/mindspore/lite/src/ops/zeros_like.h index bc323972af3..a509fa1b021 100644 --- a/mindspore/lite/src/ops/zeros_like.h +++ b/mindspore/lite/src/ops/zeros_like.h @@ -14,25 +14,23 @@ * limitations under the License. */ +#ifndef LITE_MINDSPORE_LITE_C_OPS_ZEROS_LIKE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_ZEROS_LIKE_H_ + #include #include #include #include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" -#ifdef PRIMITIVE_WRITEABLE -#include "schema/inner/model_generated.h" -#else -#include "schema/model_generated.h" -#endif - -#ifndef LITE_MINDSPORE_LITE_C_OPS_ZEROS_LIKE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ZEROS_LIKE_H_ namespace mindspore { namespace lite { class ZerosLike : public PrimitiveC { public: - explicit ZerosLike(OriginPrimitive *primitive) : PrimitiveC(primitive) {} +#ifdef PRIMITIVE_WRITEABLE + explicit ZerosLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#endif + explicit ZerosLike(schema::Primitive *primitive) : PrimitiveC(primitive) {} int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 515d18824ec..038e7532348 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -170,13 +170,14 @@ namespace mindspore::kernel { OpParameter *PopulateROIPoolingParameter(const mindspore::lite::PrimitiveC *primitive) { - const auto param = dynamic_cast(primitive); + const auto param = + reinterpret_cast(const_cast(primitive)); auto *roi_pooling_param = new (std::nothrow) ROIPoolingParameter(); if 
(param == nullptr) { MS_LOG(ERROR) << "new PoolingParameter failed."; return nullptr; } - roi_pooling_param->op_parameter_.type_ = param->Type(); + roi_pooling_param->op_parameter_.type_ = primitive->Type(); roi_pooling_param->pooledH_ = param->GetPooledW(); roi_pooling_param->pooledW_ = param->GetPooledW(); roi_pooling_param->scale_ = param->GetScale(); @@ -184,7 +185,8 @@ OpParameter *PopulateROIPoolingParameter(const mindspore::lite::PrimitiveC *prim } OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) { - const auto param = dynamic_cast(primitive); + const auto param = + reinterpret_cast(const_cast(primitive)); auto *batch_norm_param = new (std::nothrow) BatchNormParameter(); if (batch_norm_param == nullptr) { MS_LOG(ERROR) << "new BatchNormParameter failed."; @@ -197,7 +199,7 @@ OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) { } OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) { - const auto param = dynamic_cast(primitive); + const auto param = reinterpret_cast(const_cast(primitive)); auto *fill_param = new (std::nothrow) FillParameter(); if (fill_param == nullptr) { MS_LOG(ERROR) << "new FillParameter failed."; @@ -214,7 +216,7 @@ OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) } OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto *expand_dims_param = new (std::nothrow) ExpandDimsParameter(); if (expand_dims_param == nullptr) { MS_LOG(ERROR) << "new ExpandDimsParameter failed."; @@ -226,7 +228,7 @@ OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *prim } OpParameter *PopulateCaffePReLUParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto *caffePrelu_param = new (std::nothrow) 
CaffePreluParameter(); if (caffePrelu_param == nullptr) { MS_LOG(ERROR) << "new caffePReluParameter failed."; @@ -238,7 +240,7 @@ OpParameter *PopulateCaffePReLUParameter(const mindspore::lite::PrimitiveC *prim } OpParameter *PopulatePreluParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto *prelu_param = new (std::nothrow) PreluParameter(); if (prelu_param == nullptr) { MS_LOG(ERROR) << "new caffePReluParameter failed."; @@ -253,7 +255,8 @@ OpParameter *PopulatePreluParameter(const mindspore::lite::PrimitiveC *primitive } OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primitive) { - auto pooling_primitive = dynamic_cast(primitive); + auto pooling_primitive = + reinterpret_cast(const_cast(primitive)); auto *pooling_param = new (std::nothrow) PoolingParameter(); if (pooling_param == nullptr) { MS_LOG(ERROR) << "new PoolingParameter failed."; @@ -309,7 +312,8 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti } OpParameter *PopulateFullconnectionParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = + reinterpret_cast(const_cast(primitive)); auto *matmul_param = new (std::nothrow) MatMulParameter(); if (matmul_param == nullptr) { MS_LOG(ERROR) << "new FullconnectionParameter failed."; @@ -331,7 +335,7 @@ OpParameter *PopulateFullconnectionParameter(const mindspore::lite::PrimitiveC * } OpParameter *PopulateMatMulParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto *matmul_param = new (std::nothrow) MatMulParameter(); if (matmul_param == nullptr) { MS_LOG(ERROR) << "new FullconnectionParameter failed."; @@ -352,7 +356,8 @@ OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } 
conv_param->op_parameter_.type_ = primitive->Type(); - auto conv_primitive = dynamic_cast(primitive); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); conv_param->kernel_h_ = conv_primitive->GetKernelH(); conv_param->kernel_w_ = conv_primitive->GetKernelW(); conv_param->group_ = conv_primitive->GetGroup(); @@ -398,7 +403,8 @@ OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitiv } conv_param->op_parameter_.type_ = primitive->Type(); - auto conv_primitive = dynamic_cast(primitive); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); conv_param->kernel_h_ = conv_primitive->GetKernelH(); conv_param->kernel_w_ = conv_primitive->GetKernelW(); conv_param->stride_h_ = conv_primitive->GetStrideH(); @@ -439,7 +445,8 @@ OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primit return nullptr; } conv_param->op_parameter_.type_ = primitive->Type(); - auto conv_primitive = dynamic_cast(primitive); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); conv_param->kernel_h_ = conv_primitive->GetKernelH(); conv_param->kernel_w_ = conv_primitive->GetKernelW(); conv_param->stride_h_ = conv_primitive->GetStrideH(); @@ -480,7 +487,8 @@ OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } conv_param->op_parameter_.type_ = primitive->Type(); - auto conv_primitive = dynamic_cast(primitive); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); conv_param->kernel_h_ = conv_primitive->GetKernelH(); conv_param->kernel_w_ = conv_primitive->GetKernelW(); conv_param->stride_h_ = conv_primitive->GetStrideH(); @@ -533,7 +541,8 @@ OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitiv } OpParameter *PopulateSoftmaxParameter(const mindspore::lite::PrimitiveC *primitive) { - auto softmax_primitive = dynamic_cast(primitive); + auto softmax_primitive = + reinterpret_cast(const_cast(primitive)); auto *softmax_param = new 
(std::nothrow) SoftmaxParameter(); if (softmax_param == nullptr) { MS_LOG(ERROR) << "new SoftmaxParameter failed."; @@ -551,7 +560,7 @@ OpParameter *PopulateReduceParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } reduce_param->op_parameter_.type_ = primitive->Type(); - auto reduce = dynamic_cast(primitive); + auto reduce = reinterpret_cast(const_cast(primitive)); reduce_param->keep_dims_ = reduce->GetKeepDims(); auto axisVector = reduce->GetAxes(); if (axisVector.size() > REDUCE_MAX_AXES_NUM) { @@ -575,7 +584,7 @@ OpParameter *PopulateMeanParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } mean_param->op_parameter_.type_ = primitive->Type(); - auto mean = dynamic_cast(primitive); + auto mean = reinterpret_cast(const_cast(primitive)); mean_param->keep_dims_ = mean->GetKeepDims(); auto axisVector = mean->GetAxis(); if (axisVector.size() > REDUCE_MAX_AXES_NUM) { @@ -599,7 +608,7 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } pad_param->op_parameter_.type_ = primitive->Type(); - auto pad_node = dynamic_cast(primitive); + auto pad_node = reinterpret_cast(const_cast(primitive)); pad_param->pad_mode_ = pad_node->GetPaddingMode(); if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) { pad_param->constant_value_ = pad_node->GetConstantValue(); @@ -628,7 +637,8 @@ OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *prim MS_LOG(ERROR) << "new ActivationParameter failed."; return nullptr; } - auto activation = dynamic_cast(primitive); + auto activation = + reinterpret_cast(const_cast(primitive)); act_param->type_ = static_cast(activation->GetType()); act_param->alpha_ = activation->GetAlpha(); return reinterpret_cast(act_param); @@ -641,7 +651,8 @@ OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive return nullptr; } batch_norm_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto 
param = + reinterpret_cast(const_cast(primitive)); batch_norm_param->epsilon_ = param->GetEpsilon(); batch_norm_param->fused_ = true; return reinterpret_cast(batch_norm_param); @@ -658,16 +669,24 @@ OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); switch (primitive->Type()) { case schema::PrimitiveType_Add: - arithmetic_param->activation_type_ = dynamic_cast(primitive)->GetActivationType(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive)) + ->GetActivationType(); break; case schema::PrimitiveType_Sub: - arithmetic_param->activation_type_ = dynamic_cast(primitive)->GetActivationType(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive)) + ->GetActivationType(); break; case schema::PrimitiveType_Mul: - arithmetic_param->activation_type_ = dynamic_cast(primitive)->GetActivationType(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive)) + ->GetActivationType(); break; case schema::PrimitiveType_Div: - arithmetic_param->activation_type_ = dynamic_cast(primitive)->GetActivationType(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive)) + ->GetActivationType(); break; default: arithmetic_param->activation_type_ = 0; @@ -688,7 +707,7 @@ OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primiti MS_LOG(ERROR) << "new ArithmeticParameter failed."; return nullptr; } - auto eltwise = dynamic_cast(primitive); + auto eltwise = reinterpret_cast(const_cast(primitive)); switch (eltwise->GetMode()) { case schema::EltwiseMode_PROD: arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; @@ -723,7 +742,7 @@ OpParameter *PopulatePowerParameter(const mindspore::lite::PrimitiveC *primitive return nullptr; } power_param->op_parameter_.type_ = primitive->Type(); - auto power = dynamic_cast(primitive); + auto power = 
reinterpret_cast(const_cast(primitive)); power_param->power_ = power->GetPower(); power_param->scale_ = power->GetScale(); power_param->shift_ = power->GetShift(); @@ -737,7 +756,7 @@ OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } arg_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); arg_param->axis_ = param->GetAxis(); arg_param->topk_ = param->GetTopK(); arg_param->axis_type_ = param->GetAxisType(); @@ -753,7 +772,7 @@ OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } arg_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); arg_param->axis_ = param->GetAxis(); arg_param->topk_ = param->GetTopK(); arg_param->axis_type_ = param->GetAxisType(); @@ -769,14 +788,15 @@ OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } cast_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); cast_param->src_type_ = param->GetSrcT(); cast_param->dst_type_ = param->GetDstT(); return reinterpret_cast(cast_param); } OpParameter *PopulateLocalResponseNormParameter(const mindspore::lite::PrimitiveC *primitive) { - auto local_response_norm_attr = dynamic_cast(primitive); + auto local_response_norm_attr = reinterpret_cast( + const_cast(primitive)); auto *lrn_param = new (std::nothrow) LocalResponseNormParameter(); if (lrn_param == nullptr) { MS_LOG(ERROR) << "new LocalResponseNormParameter failed."; @@ -791,7 +811,7 @@ OpParameter *PopulateLocalResponseNormParameter(const mindspore::lite::Primitive } OpParameter *PopulateRangeParameter(const mindspore::lite::PrimitiveC *primitive) { - auto range_attr = dynamic_cast(primitive); + auto range_attr = 
reinterpret_cast(const_cast(primitive)); auto *range_param = new (std::nothrow) RangeParameter(); if (range_param == nullptr) { MS_LOG(ERROR) << "new RangeParameter failed."; @@ -812,7 +832,7 @@ OpParameter *PopulateConcatParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } concat_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); concat_param->axis_ = param->GetAxis(); return reinterpret_cast(concat_param); } @@ -824,7 +844,7 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } tile_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto multiples = param->GetMultiples(); tile_param->in_dim_ = multiples.size(); for (int i = 0; i < tile_param->in_dim_; ++i) { @@ -840,7 +860,7 @@ OpParameter *PopulateTopKParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } topk_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); topk_param->k_ = param->GetK(); topk_param->sorted_ = param->GetSorted(); return reinterpret_cast(topk_param); @@ -872,7 +892,7 @@ OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primi MS_LOG(ERROR) << "new TransposeParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); transpose_param->op_parameter_.type_ = primitive->Type(); auto perm_vector_ = param->GetPerm(); int i = 0; @@ -890,7 +910,7 @@ OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive MS_LOG(ERROR) << "new SplitParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); split_param->op_parameter_.type_ = primitive->Type(); 
split_param->num_split_ = param->GetNumberSplit(); auto split_sizes_vector_ = param->GetSizeSplits(); @@ -924,17 +944,13 @@ OpParameter *PopulateScaleParameter(const mindspore::lite::PrimitiveC *primitive return nullptr; } scale_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); - if (param == nullptr) { - MS_LOG(ERROR) << "value_as_Scale return nullptr"; - return nullptr; - } + auto param = reinterpret_cast(const_cast(primitive)); scale_param->axis_ = param->GetAxis(); return reinterpret_cast(scale_param); } OpParameter *PopulateGatherParameter(const mindspore::lite::PrimitiveC *primitive) { - auto gather_attr = dynamic_cast(primitive); + auto gather_attr = reinterpret_cast(const_cast(primitive)); auto *gather_param = new (std::nothrow) GatherParameter(); if (gather_param == nullptr) { MS_LOG(ERROR) << "new GatherParameter failed."; @@ -953,7 +969,8 @@ OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primit return nullptr; } gather_nd_param->op_parameter_.type_ = primitive->Type(); - auto gatherNd_attr = dynamic_cast(primitive); + auto gatherNd_attr = + reinterpret_cast(const_cast(primitive)); gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims(); return reinterpret_cast(gather_nd_param); } @@ -975,7 +992,7 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive MS_LOG(ERROR) << "new SliceParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); slice_param->op_parameter_.type_ = primitive->Type(); auto param_begin = param->GetBegin(); auto param_size = param->GetSize(); @@ -997,7 +1014,7 @@ OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *pri MS_LOG(ERROR) << "new BroadcastToParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); broadcast_param->op_parameter_.type_ = primitive->Type(); 
auto dst_shape = param->GetDstShape(); broadcast_param->shape_size_ = dst_shape.size(); @@ -1028,7 +1045,8 @@ OpParameter *PopulateShapeParameter(const mindspore::lite::PrimitiveC *primitive } OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC *primitive) { - auto attr = dynamic_cast(primitive); + auto attr = + reinterpret_cast(const_cast(primitive)); ConstantOfShapeParameter *param = new (std::nothrow) ConstantOfShapeParameter(); if (param == nullptr) { MS_LOG(ERROR) << "new ConstantOfShapeParameter failed."; @@ -1040,7 +1058,8 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC } OpParameter *PopulateReverseParameter(const mindspore::lite::PrimitiveC *primitive) { - auto reverse_attr = dynamic_cast(primitive); + auto reverse_attr = + reinterpret_cast(const_cast(primitive)); ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); if (reverse_param == nullptr) { MS_LOG(ERROR) << "new ReverseParameter failed."; @@ -1057,7 +1076,8 @@ OpParameter *PopulateReverseParameter(const mindspore::lite::PrimitiveC *primiti } OpParameter *PopulateUnsqueezeParameter(const mindspore::lite::PrimitiveC *primitive) { - auto unsqueeze_attr = dynamic_cast(primitive); + auto unsqueeze_attr = + reinterpret_cast(const_cast(primitive)); auto *unsqueeze_param = new (std::nothrow) UnsqueezeParameter(); if (unsqueeze_param == nullptr) { MS_LOG(ERROR) << "new ReverseParameter failed."; @@ -1079,7 +1099,7 @@ OpParameter *PopulateStackParameter(const mindspore::lite::PrimitiveC *primitive MS_LOG(ERROR) << "new StackParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); stack_param->op_parameter_.type_ = primitive->Type(); stack_param->axis_ = param->GetAxis(); return reinterpret_cast(stack_param); @@ -1091,7 +1111,7 @@ OpParameter *PopulateUnstackParameter(const mindspore::lite::PrimitiveC *primiti MS_LOG(ERROR) << "new UnstackParameter 
failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); unstack_param->op_parameter_.type_ = primitive->Type(); unstack_param->num_ = param->GetNum(); unstack_param->axis_ = param->GetAxis(); @@ -1104,7 +1124,8 @@ OpParameter *PopulateReverseSequenceParameter(const mindspore::lite::PrimitiveC MS_LOG(ERROR) << "new ReverseSequenceParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = + reinterpret_cast(const_cast(primitive)); reverse_sequence_param->op_parameter_.type_ = primitive->Type(); reverse_sequence_param->seq_axis_ = param->GetSeqAxis(); reverse_sequence_param->batch_axis_ = param->GetBatchAxis(); @@ -1127,7 +1148,7 @@ OpParameter *PopulateDepthToSpaceParameter(const mindspore::lite::PrimitiveC *pr MS_LOG(ERROR) << "new DepthToSpaceParameter failed."; return nullptr; } - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); depth_space_param->op_parameter_.type_ = primitive->Type(); depth_space_param->block_size_ = param->GetBlockSize(); return reinterpret_cast(depth_space_param); @@ -1140,7 +1161,7 @@ OpParameter *PopulateSpaceToDepthParameter(const mindspore::lite::PrimitiveC *pr return nullptr; } space_depth_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); space_depth_param->op_parameter_.type_ = primitive->Type(); space_depth_param->block_size_ = param->GetBlockSize(); if (param->GetFormat() != schema::Format_NHWC) { @@ -1176,7 +1197,7 @@ OpParameter *PopulateResizeParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } resize_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); resize_param->method_ = static_cast(param->GetMethod()); resize_param->new_height_ = param->GetNewHeight(); resize_param->new_width_ = 
param->GetNewWidth(); @@ -1192,7 +1213,7 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr return nullptr; } batch_space_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto block_shape = param->GetBlockShape(); if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; @@ -1216,7 +1237,7 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr } OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive) { - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); auto param_offset = param->GetOffsets(); if (param_offset.size() > CROP_OFFSET_MAX_SIZE) { MS_LOG(ERROR) << "crop_param offset size(" << param_offset.size() << ") should <= " << CROP_OFFSET_MAX_SIZE; @@ -1243,7 +1264,7 @@ OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitiv return nullptr; } one_hot_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); if (param == nullptr) { delete (one_hot_param); MS_LOG(ERROR) << "get OneHot param nullptr."; @@ -1270,7 +1291,8 @@ OpParameter *PopulateQuantDTypeCastParameter(const mindspore::lite::PrimitiveC * return nullptr; } parameter->op_parameter_.type_ = primitive->Type(); - auto quant_dtype_cast_param = dynamic_cast(primitive); + auto quant_dtype_cast_param = + reinterpret_cast(const_cast(primitive)); parameter->srcT = quant_dtype_cast_param->GetSrcT(); parameter->dstT = quant_dtype_cast_param->GetDstT(); return reinterpret_cast(parameter); @@ -1313,7 +1335,8 @@ OpParameter *PopulatePriorBoxParameter(const mindspore::lite::PrimitiveC *primit return nullptr; } prior_box_param->op_parameter_.type_ = primitive->Type(); - auto prior_box_attr 
= dynamic_cast(primitive); + auto prior_box_attr = + reinterpret_cast(const_cast(primitive)); if (prior_box_attr->GetMinSizes().size() > PRIOR_BOX_MAX_NUM) { MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " @@ -1367,7 +1390,7 @@ OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } lstm_param->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); if (param == nullptr) { delete (lstm_param); MS_LOG(ERROR) << "get Lstm param nullptr."; @@ -1384,7 +1407,8 @@ OpParameter *PopulateEmbeddingLookupParameter(const mindspore::lite::PrimitiveC return nullptr; } embedding_lookup_parameter->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = + reinterpret_cast(const_cast(primitive)); embedding_lookup_parameter->max_norm_ = param->GetMaxNorm(); if (embedding_lookup_parameter->max_norm_ < 0) { MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got " @@ -1412,7 +1436,7 @@ OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive) return nullptr; } elu_parameter->op_parameter_.type_ = primitive->Type(); - auto param = dynamic_cast(primitive); + auto param = reinterpret_cast(const_cast(primitive)); elu_parameter->alpha_ = param->GetAlpha(); return reinterpret_cast(elu_parameter); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc index e5871566a7d..4cfb4a50917 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc @@ -14,11 +14,11 @@ * limitations under the License. 
*/ -#include "schema/model_generated.h" +#include "src/runtime/kernel/arm/base/reduce_base.h" #include "src/kernel_registry.h" +#include "schema/model_generated.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" -#include "src/runtime/kernel/arm/base/reduce_base.h" #include "src/runtime/kernel/arm/fp32/reduce.h" #include "src/runtime/kernel/arm/int8/reduce_int8.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc index 3f8992caa09..7dcc4b18288 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc @@ -15,10 +15,10 @@ */ #include +#include "src/runtime/kernel/arm/base/resize_base.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" -#include "src/runtime/kernel/arm/base/resize_base.h" #include "src/runtime/kernel/arm/fp32/resize.h" #include "src/runtime/kernel/arm/int8/resize_int8.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc index 277d23c8761..57fd2940724 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc @@ -14,9 +14,9 @@ * limitations under the License. 
*/ +#include "src/runtime/kernel/arm/fp32/arithmetic_self.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "src/runtime/kernel/arm/fp32/arithmetic_self.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc index 2a1abbb5058..168a83c8c7e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc @@ -14,9 +14,8 @@ * limitations under the License. */ -#include -#include "schema/model_generated.h" #include "src/kernel_registry.h" +#include "schema/model_generated.h" #include "src/runtime/kernel/arm/fp32/pad.h" #include "include/errorcode.h" #include "src/runtime/kernel/arm/nnacl/errorcode.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc index ff6ec3812d7..90a2e5aad82 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/reduce_int8.cc @@ -14,13 +14,12 @@ * limitations under the License. 
*/ -#include +#include "src/runtime/kernel/arm/int8/reduce_int8.h" #include "schema/model_generated.h" -#include "src/runtime/runtime_api.h" #include "src/kernel_registry.h" +#include "src/runtime/runtime_api.h" #include "nnacl/quantization/quantize.h" #include "include/errorcode.h" -#include "src/runtime/kernel/arm/int8/reduce_int8.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; @@ -83,8 +82,7 @@ int ReduceInt8CPUKernel::Init() { last_reducer_ = ReduceSumSquareLastAxis; break; } - default: - MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; + default:MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; return RET_ERROR; } if (!InferShapeDone()) { @@ -117,7 +115,7 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() { for (auto i = 0; i < num_axes_; i++) { auto axis = axes_[i]; double reciprocal = 1.0 / in_tensors_.at(0)->shape()[axis]; - QuantMulArg *qm = new (std::nothrow) QuantMulArg; + QuantMulArg *qm = new(std::nothrow) QuantMulArg; if (qm == nullptr) { MS_LOG(ERROR) << "Reduce new QuantMulArg failed."; return RET_NULL_PTR; @@ -135,7 +133,7 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() { if (mode_ == static_cast(schema::ReduceMode_ReduceProd)) { for (auto i = 0; i < num_axes_; i++) { int axis_size = in_tensors_.at(0)->shape()[axes_[i]]; - QuantMulArg *qm = new (std::nothrow) QuantMulArg; + QuantMulArg *qm = new(std::nothrow) QuantMulArg; if (qm == nullptr) { MS_LOG(ERROR) << "ReduceProd new QuantMulArg failed."; return RET_NULL_PTR; @@ -153,7 +151,7 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() { // scale_in * scale_in/scale_out if (mode_ == static_cast(schema::ReduceMode_ReduceSumSquare)) { for (auto i = 0; i < num_axes_ - 1; i++) { - QuantMulArg *qm = new (std::nothrow) QuantMulArg; + QuantMulArg *qm = new(std::nothrow) QuantMulArg; if (qm == nullptr) { MS_LOG(ERROR) << "ReduceProd new QuantMultiplier failed."; return RET_NULL_PTR; @@ -165,7 +163,7 @@ int ReduceInt8CPUKernel::CalculateQuantArgs() { 
sum_square_multipliers_.push_back(qm); } - QuantMulArg *qm = new (std::nothrow) QuantMulArg; + QuantMulArg *qm = new(std::nothrow) QuantMulArg; if (qm == nullptr) { MS_LOG(ERROR) << "ReduceProd new QuantMultiplier failed."; return RET_NULL_PTR; @@ -338,7 +336,14 @@ int ReduceInt8CPUKernel::CallReduceUnit(int task_id) { int ret; if (!is_last_axis_) { ret = - reducer_(outer_size_, inner_size_, axis_size_, src_data_, dst_data_, &quant_arg_, task_id, context_->thread_num_); + reducer_(outer_size_, + inner_size_, + axis_size_, + src_data_, + dst_data_, + &quant_arg_, + task_id, + context_->thread_num_); } else { ret = last_reducer_(outer_size_, inner_size_, axis_size_, src_data_, last_dst_data_, &quant_arg_, task_id, context_->thread_num_); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc index 191e7e5ec1b..fd1581b56c5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc @@ -15,9 +15,9 @@ */ #include -#include "schema/model_generated.h" #include "src/kernel_registry.h" #include "nnacl/int8/resize.h" +#include "schema/model_generated.h" #include "include/errorcode.h" #include "src/runtime/kernel/arm/int8/resize_int8.h" #include "src/runtime/runtime_api.h" @@ -41,9 +41,9 @@ int ResizeInt8CPUKernel::Init() { if (ret != RET_OK) { return ret; } - quant_in_ = new (std::nothrow) QuantArg; + quant_in_ = new(std::nothrow) QuantArg; MS_ASSERT(quant_in_); - quant_out_ = new (std::nothrow) QuantArg; + quant_out_ = new(std::nothrow) QuantArg; MS_ASSERT(quant_out_); auto input = in_tensors_.at(0); quant_in_->zp_ = input->GetQuantParams().front().zeroPoint; @@ -52,7 +52,7 @@ int ResizeInt8CPUKernel::Init() { quant_out_->zp_ = output->GetQuantParams().front().zeroPoint; quant_out_->scale_ = output->GetQuantParams().front().scale; - multiplier_ = new (std::nothrow) QuantMulArg; + multiplier_ = new(std::nothrow) QuantMulArg; 
MS_ASSERT(multiplier_); QuantizeRoundParameter(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, &multiplier_->left_shift_, &multiplier_->right_shift_); @@ -101,12 +101,25 @@ int ResizeInt8CPUKernel::RunImpl(int task_id) { bool same_scale = abs(quant_out_->scale_ - quant_in_->scale_) < 1e-6; if (same_zp && same_scale) { ret = - ResizeNearestNeighborInt8Simple(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), - align_corners_, task_id, context_->thread_num_); + ResizeNearestNeighborInt8Simple(input_data, + output_data, + input_shape.data(), + out_tensors_[0]->shape().data(), + align_corners_, + task_id, + context_->thread_num_); } else { ret = - ResizeNearestNeighborInt8(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), - align_corners_, multiplier_, quant_in_, quant_out_, task_id, context_->thread_num_); + ResizeNearestNeighborInt8(input_data, + output_data, + input_shape.data(), + out_tensors_[0]->shape().data(), + align_corners_, + multiplier_, + quant_in_, + quant_out_, + task_id, + context_->thread_num_); } break; } diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index df276550d82..62328753098 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -16,7 +16,6 @@ #include "src/scheduler.h" #include -#include #include "include/errorcode.h" #include "src/kernel_registry.h" #include "src/common/graph_util.h" @@ -116,6 +115,7 @@ int Scheduler::InferShape(const lite::Model *model, std::vectorset_is_model_output(IsContain(graph_output_node_indexes, size_t(i))); kernels->emplace_back(kernel); } + return RET_OK; } diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 2a8dbefad6c..8c1649b0ab9 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -180,7 +180,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/runtime/workspace_pool.cc ${LITE_DIR}/src/runtime/parallel_executor.cc 
${LITE_DIR}/src/ir/tensor.cc - ${LITE_DIR}/src/ir/primitive_t_value.cc +# ${LITE_DIR}/src/ir/primitive_t_value.cc ${LITE_DIR}/src/context.cc ${LITE_DIR}/src/executor.cc ${LITE_DIR}/src/kernel_registry.cc @@ -219,6 +219,7 @@ if (SUPPORT_GPU) endif() ### converter if(BUILD_CONVERTER) + add_definitions(-DPRIMITIVE_WRITEABLE) file(GLOB_RECURSE TEST_CASE_TFLITE_PARSERS_SRC ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc ) diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc index 710afbcb25d..edb53248e79 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc @@ -16,8 +16,8 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" -#include "mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h" #include "mindspore/lite/src/kernel_registry.h" #include "mindspore/lite/include/context.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc index 1fce620899b..5e50ae45a22 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_self_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc index ee246454560..42c91c3dfae 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc +++ 
b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include +#include "schema/inner/model_generated.h" #include "mindspore/core/utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/batchnorm_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/bias_add_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/bias_add_int8_tests.cc index d65632c5fd0..6312d5f2f29 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/bias_add_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/bias_add_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc index 3cd925a86c0..1b04a943e34 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/concat_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc index 0d02008f262..12ae5e6650a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include 
"mindspore/lite/src/runtime/kernel/arm/nnacl/crop_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc index a678f0d9a29..2b9e2810179 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "src/common/file_utils.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc index 48206673fc5..4cdb92440ca 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/div_int8_test.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/div_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc index 52666f62727..e6dfdb831eb 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc index a564bf87e17..5caeff4be8c 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/fp32/activation.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index 22a1af6a772..d2ba6848832 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc index 9830e165c67..6732ae4014b 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/mul_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc index 4676f4cb989..00220e4ace2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "include/context.h" #include "src/ir/tensor.h" #include "common/common_test.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc index 9fa8addd341..03f0cd2d561 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/power_int8.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/power_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc index 
e987a74c33f..a465e890d34 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/prelu_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc index ad89c45a0bb..cc89cf0bba6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc @@ -15,6 +15,7 @@ */ #include #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc index f3e78a50147..59e0f080323 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc @@ -14,6 +14,7 @@ * limitations under the License. 
*/ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "src/ir/tensor.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc index 920902b2000..c584bfa3107 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc index 9a2131dc73a..583c03fe6bd 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/reshape_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc index 7ba5ec8b1a1..87a128a4c14 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "include/context.h" #include "src/ir/tensor.h" #include "common/common_test.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc 
b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc index ce29982f178..5189cbb7e74 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "include/context.h" #include "src/ir/tensor.h" #include "common/common_test.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc index ff8e4141627..f4de2f8fcb9 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/slice_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/slice_int8_tests.cc index 5e27ec6aaa1..50857b264e8 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/slice_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/slice_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc index c87a3ede504..89776a16263 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc @@ 
-16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/softmax_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc index a224b9cdadb..e73242c7edb 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/split_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc index 6bb7fe9cacd..cabb2f4b023 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze_parameter.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc index 28966e03b90..9db48749782 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sub_int_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/int8/sub_int8.h" #include "mindspore/lite/src/kernel_registry.h" diff --git 
a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc index 302a3b388e1..d4d3656ea12 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc @@ -16,6 +16,7 @@ #include #include +#include "schema/inner/model_generated.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/topk.h" #include "mindspore/lite/src/kernel_registry.h" diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc index 41acb551880..9e0835f5bb2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc @@ -15,6 +15,7 @@ */ #include +#include "schema/inner/model_generated.h" #include "utils/log_adapter.h" #include "common/common_test.h" #include "mindspore/lite/src/runtime/kernel/arm/nnacl/unsqueeze_parameter.h" diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index ed0619cfc32..e63a4d59ca7 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -110,7 +110,7 @@ bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr &m } int AnfExporter::ConvertQuantParam(const std::unique_ptr &meta_graph, - const std::shared_ptr primitive, + const std::shared_ptr primitive, const std::unique_ptr &dst_node) { MS_ASSERT(meta_graph != nullptr); MS_ASSERT(primitive != nullptr); @@ -121,7 +121,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; // activation auto input_quant_params = primitive->GetInputQuantParams(); - auto node_type = 
primitive->GetPrimitiveT()->value.type; + auto node_type = (schema::PrimitiveType) primitive->Type(); for (size_t i = 0; i < input_quant_params.size(); i++) { if (i >= dst_node->inputIndex.size()) { MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() @@ -160,7 +160,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me } if (dst_node->quantType != schema::QuantType_AwareTraining && !(node_type == schema::PrimitiveType_QuantDTypeCast && - primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { + primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { tensor_output->dataType = kNumberTypeInt8; } } @@ -186,7 +186,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { auto cnodes = func_graph->GetOrderedCnodes(); auto meta_graphT = std::make_unique(); for (const auto &cnode : cnodes) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; return nullptr; @@ -431,15 +431,11 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType return false; } - const auto &prim = GetValueNode>(cnode->input(0)); + const auto &prim = GetValueNode>(cnode->input(0)); if (prim == nullptr) { return false; } - auto *primitiveT = prim->GetPrimitiveT(); - if (primitiveT == nullptr) { - return false; - } - return primitiveT->value.type == type; + return (schema::PrimitiveType) prim->Type() == type; } schema::MetaGraphT *Export(const FuncGraphPtr &func_graph) { diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index d8917f4417b..8b6c30dd412 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -22,7 +22,7 @@ #include #include #include "schema/inner/model_generated.h" 
-#include "src/ir/primitive_t_value.h" +#include "src/ops/primitive_c.h" #include "ir/func_graph.h" namespace mindspore::lite { @@ -48,8 +48,7 @@ class AnfExporter { void SetGraphInputIndex(const std::unique_ptr &meta_graphT); bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); int ConvertQuantParam(const std::unique_ptr &meta_graph, - const std::shared_ptr primitive, - const std::unique_ptr &dst_node); + const std::shared_ptr primitive, const std::unique_ptr &dst_node); private: std::map node_id_map_; diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc index 12eae7f4f06..8310366196a 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc @@ -21,7 +21,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -35,8 +35,8 @@ int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue * primitive->value.type = schema::PrimitiveType_Activation; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h index d976a8a4e9a..f2f18f9c0d1 100644 --- 
a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h @@ -18,12 +18,14 @@ #define MINDSPORE_ANF_ACTIVATION_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + + namespace mindspore::lite { class AnfActivationPopulater : public AnfNodePopulater { public: AnfActivationPopulater() = default; ~AnfActivationPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc index cca157e370d..64deba38aca 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc @@ -21,15 +21,15 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); attr->epsilon = GetValue(prim->GetAttr("epsilon")); primitive->value.type = schema::PrimitiveType_FusedBatchNorm; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h 
b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h index 92fb87e6bb7..d4c39e0b710 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_BATCHNORM_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfBatchnormPopulater : public AnfNodePopulater { public: AnfBatchnormPopulater() = default; ~AnfBatchnormPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc index e72ce6dba8a..856720578f1 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc @@ -21,15 +21,15 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); attr->axis = {0}; primitive->value.type = schema::PrimitiveType_BiasAdd; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h index 
508e47ef04e..bcb3db8cc7f 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_BIASADD_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfBiasAddPopulater : public AnfNodePopulater { public: AnfBiasAddPopulater() = default; ~AnfBiasAddPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc index 0294865bd26..50bc16106e3 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc @@ -23,7 +23,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -31,8 +31,8 @@ int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim attr->axis = prim_axis; primitive->value.type = schema::PrimitiveType_Concat; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h index 
b10845bbfb7..6abd8e907ba 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h @@ -18,12 +18,13 @@ #define MINDSPORE_ANF_CONCAT_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfConcatPopulater : public AnfNodePopulater { public: AnfConcatPopulater() = default; ~AnfConcatPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtrr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc index 6e4a50e29c7..cbef4d5047f 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc @@ -214,9 +214,9 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, } } -int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { - MS_ASSERT(primitiveTValuePtr != nullptr); + MS_ASSERT(primitiveCPtr != nullptr); auto primitive = std::make_unique(); int group = GetValue(prim->GetAttr("group")); @@ -225,14 +225,14 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit } else { PopulaterConv2DSingleGroup(prim, primitive, group); } - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + primitiveCPtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { + if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> 
vecOutputQuantParam; PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); - primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); + primitiveCPtr->SetInputQuantParam(vecInputQuantParam); + primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); } return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h index 3e32d013d2a..f2c3d7b30dc 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h @@ -30,7 +30,7 @@ class AnfConvPopulater : public AnfNodePopulater { public: AnfConvPopulater() = default; ~AnfConvPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; private: diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc index 583e8df4233..abc78559c41 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc @@ -116,7 +116,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( } } -int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -178,15 +178,15 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; primitive->value.value = attr.release(); - 
MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { + if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); - primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); + primitiveCPtr->SetInputQuantParam(vecInputQuantParam); + primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); } return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h index 005f132516b..110e5ae5502 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h @@ -24,7 +24,7 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { public: AnfDepwiseconv2DPopulater() = default; ~AnfDepwiseconv2DPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; private: diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc index 4c88cce9da1..f9d7cfaea65 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc @@ -22,14 +22,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue 
*primitiveTValuePtr, +int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h index 77bb3f2b5f0..20f55f8da60 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_DEQUANT_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfDequantPopulater : public AnfNodePopulater { public: AnfDequantPopulater() = default; ~AnfDequantPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc index db80e41463e..ce5645fb350 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc @@ -21,14 +21,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue 
*primitiveTValuePtr, +int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_Flatten; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h index 5366873fc1a..01c0f0a1dd1 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_FLATTEN_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfFlattenPopulater : public AnfNodePopulater { public: AnfFlattenPopulater() = default; ~AnfFlattenPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc index 5fae271e4a5..16d40c11039 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc @@ -21,14 +21,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfMakeTuplePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfMakeTuplePopulater::Populate(const PrimitivePtr &prim, PrimitiveC 
*primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_MakeTuple; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfMakeTuplePopulater("make_tuple", new AnfMakeTuplePopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h index 973ae8fc15c..d8b283e41ec 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_MAKE_TUPLE_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfMakeTuplePopulater : public AnfNodePopulater { public: AnfMakeTuplePopulater() = default; ~AnfMakeTuplePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc index de3a84ee37f..261f0cff239 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc @@ -101,7 +101,7 @@ void AnfMatmulPopulater::PopulaterQuantParam( } } -int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveC 
*primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -110,14 +110,14 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim primitive->value.type = schema::PrimitiveType_MatMul; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); + if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); - primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); + primitiveCPtr->SetInputQuantParam(vecInputQuantParam); + primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); } return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h index d99cf57339d..e67b2e7d157 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_MATMUL_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfMatmulPopulater : public AnfNodePopulater { public: AnfMatmulPopulater() = default; ~AnfMatmulPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; private: diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc 
b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc index 0ba673b49aa..bc99452582d 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc @@ -21,14 +21,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_Mul; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h index 30eb3f7173d..1dad59e86ba 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h @@ -17,12 +17,13 @@ #define MINDSPORE_ANF_ACTIVATION_PARSER_H #include "tools/anf_importer/anf_populater/anf_node_populater.h" #include + namespace mindspore::lite { class AnfMulPopulater : public AnfNodePopulater { public: AnfMulPopulater() = default; ~AnfMulPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h index 
6270a1ea18f..d0290fe6964 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h @@ -19,8 +19,9 @@ #include #include "ir/anf.h" -#include "src/ir/primitive_t_value.h" #include "schema/inner/model_generated.h" +#include "src/ops/primitive_c.h" + namespace mindspore::lite { constexpr int kAnfPopulaterOne = 1; constexpr int kAnfPopulaterTwo = 2; @@ -30,7 +31,7 @@ class AnfNodePopulater { AnfNodePopulater() = default; virtual ~AnfNodePopulater() = default; - virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + virtual int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) = 0; }; diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc index 5f06a84d2a9..1e2cc41a345 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc @@ -22,7 +22,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -60,8 +60,8 @@ int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit primitive->value.type = schema::PrimitiveType_Pooling; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater()); diff --git 
a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h index 7aefc444092..d2228dc45d0 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h @@ -22,7 +22,7 @@ class AnfPoolPopulater : public AnfNodePopulater { public: AnfPoolPopulater() = default; ~AnfPoolPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc index 98b858180db..69c0d514fac 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc @@ -22,14 +22,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h index e7eec3cb09b..28d3c96229a 100644 --- 
a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h @@ -22,7 +22,7 @@ class AnfQuantPopulater : public AnfNodePopulater { public: AnfQuantPopulater() = default; ~AnfQuantPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc index 9c24375dd52..a0818db075f 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc @@ -21,7 +21,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -50,8 +50,8 @@ int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue * primitive->value.type = schema::PrimitiveType_Reduce; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h index ba4a1c6b3a8..15e07c08c72 100644 --- 
a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h @@ -22,7 +22,7 @@ class AnfReduceMeanPopulater : public AnfNodePopulater { public: AnfReduceMeanPopulater() = default; ~AnfReduceMeanPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc index ce86fc780b0..1252c068416 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc @@ -21,7 +21,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -45,8 +45,8 @@ int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *pri primitive->value.type = schema::PrimitiveType_Reshape; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h index fd2d35a8753..a0ce7513a19 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h +++ 
b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h @@ -23,7 +23,7 @@ class AnfReshapePopulater : public AnfNodePopulater { public: AnfReshapePopulater() = default; ~AnfReshapePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc index 4b8d951db89..428ac6d404e 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc @@ -21,14 +21,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_Add; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h index 7b990bbb850..e681cef48e6 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h @@ -22,7 +22,7 @@ class AnfTensorAddPopulater : public AnfNodePopulater { 
public: AnfTensorAddPopulater() = default; ~AnfTensorAddPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc index 9eb790f4d68..4e6c1d65c85 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc @@ -22,7 +22,7 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -46,8 +46,8 @@ int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *p primitive->value.type = schema::PrimitiveType_Transpose; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h index 60281c8fd1f..00699307acc 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h @@ -22,7 +22,7 @@ class AnfTransposePopulater : public AnfNodePopulater { public: AnfTransposePopulater() = 
default; ~AnfTransposePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc index a63f244ce00..d903fe41239 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc @@ -21,14 +21,14 @@ #include "ir/primitive.h" namespace mindspore::lite { -int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, +int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); primitive->value.type = schema::PrimitiveType_TupleGetItem; primitive->value.value = attr.release(); - MS_ASSERT(primitiveTValuePtr != nullptr); - primitiveTValuePtr->SetPrimitiveT(primitive.release()); + MS_ASSERT(primitiveCPtr != nullptr); + primitiveCPtr->SetPrimitiveT(primitive.release()); return 0; } AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h index 8e0c835d498..caf670e67e7 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h @@ -22,7 +22,7 @@ class AnfTupleGetItemPopulater : public AnfNodePopulater { public: AnfTupleGetItemPopulater() = default; ~AnfTupleGetItemPopulater() override = default; - int Populate(const 
PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, + int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, const std::vector &inputs) override; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index 233400f7389..fe4d64eb93a 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -23,7 +23,6 @@ #include "utils/log_adapter.h" #include "include/errorcode.h" - namespace mindspore::lite { int AnfImporterFromMetaGraphT::ConverterConstTensor() { MS_ASSERT(nullptr != meta_graph_); @@ -61,17 +60,17 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { param_value->set_tensor_addr(tensor_data); param_value->set_tensor_size(size); } -// if (!tensor->quantParams.empty()) { -// std::unique_ptr quantParam = std::make_unique(); -// quantParam->scale = tensor->quantParams[0]->scale; -// quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; -// quantParam->min = tensor->quantParams[0]->min; -// quantParam->max = tensor->quantParams[0]->max; -// quantParam->narrowRange = tensor->quantParams[0]->narrowRange; -// quantParam->numBits = tensor->quantParams[0]->numBits; -// quantParam->inited = tensor->quantParams[0]->inited; -// param_value->set_quant_param(quantParam); -// } + // if (!tensor->quantParams.empty()) { + // std::unique_ptr quantParam = std::make_unique(); + // quantParam->scale = tensor->quantParams[0]->scale; + // quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; + // quantParam->min = tensor->quantParams[0]->min; + // quantParam->max = tensor->quantParams[0]->max; + // quantParam->narrowRange = tensor->quantParams[0]->narrowRange; + // quantParam->numBits = tensor->quantParams[0]->numBits; + // quantParam->inited = tensor->quantParams[0]->inited; + // param_value->set_quant_param(quantParam); + // } 
parameter->set_default_param(param_value); AddNode(i, parameter); } @@ -81,25 +80,25 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr &cNode) { MS_ASSERT(nullptr != meta_graph_); MS_ASSERT(nullptr != cNode); - auto primTValue = std::make_shared(cNode->primitive.release()); + auto primitiveCValue = std::make_shared(cNode->primitive.release()); cNode->primitive = nullptr; // add quant parameter if (cNode->quantType == schema::QuantType_AwareTraining) { - primTValue->SetQuantType(cNode->quantType); + primitiveCValue->SetQuantType(cNode->quantType); for (int index : cNode->inputIndex) { if (meta_graph_->allTensors[index]->quantParams.size() > 0) { std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; - primTValue->AddInputQuantParam(quant_params); + primitiveCValue->AddInputQuantParam(quant_params); } } for (int index : cNode->outputIndex) { if (meta_graph_->allTensors[index]->quantParams.size() > 0) { std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; - primTValue->AddOutputQuantParam(quant_params); + primitiveCValue->AddOutputQuantParam(quant_params); } } } - auto value_node = NewValueNode(primTValue); + auto value_node = NewValueNode(primitiveCValue); return value_node; } diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h index 0e8f3e8ca2e..bed226c4d43 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h @@ -21,7 +21,7 @@ #include #include "schema/inner/model_generated.h" #include "tools/anf_importer/anf_importer.h" -#include "src/ir/primitive_t_value.h" +#include "src/ops/primitive_c.h" #include "abstract/abstract_value.h" namespace mindspore::lite { diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc 
b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index 0c1e046bebf..b6a39f382bf 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -60,16 +60,16 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, 
kObjectTypeString}, }; #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ @@ -228,8 +228,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr auto value = prim->GetAttr(attr_name); break; } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } return true; @@ -292,8 +291,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con case FORM_PARSE_TENSOR: { return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); } - default: - MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; return false; } } @@ -357,8 +355,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val value_ptr = std::make_shared(elems); break; } - default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } auto new_value_node = NewValueNode(value_ptr); @@ -396,8 +393,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at case FORM_PARSE_TYPE: { return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); } - default: - MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; + default:MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; return false; } } @@ -481,11 +477,11 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out return nullptr; } auto primitiveT = std::make_unique(); - std::shared_ptr primitiveTValuePtr = std::make_shared(primitiveT.release()); - primitiveTValuePtr->SetQuantType(quantType); - node_parser->Populate(prim, 
primitiveTValuePtr.get(), inputs); - MS_ASSERT(primitiveTValuePtr != nullptr); - inputs.insert(inputs.begin(), NewValueNode(primitiveTValuePtr)); + std::shared_ptr primitiveCPtr = std::make_shared(primitiveT.release()); + primitiveCPtr->SetQuantType(quantType); + node_parser->Populate(prim, primitiveCPtr.get(), inputs); + MS_ASSERT(primitiveCPtr != nullptr); + inputs.insert(inputs.begin(), NewValueNode(primitiveCPtr)); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); if (node_type == "LayerNorm") { @@ -523,9 +519,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto primitiveT = std::make_unique(); MS_ASSERT(primitiveT != nullptr); primitiveT->value.type = schema::PrimitiveType_MakeTuple; - std::shared_ptr primitiveTValuePtr = std::make_shared(primitiveT.release()); - MS_ASSERT(primitiveTValuePtr != nullptr); - inputs.push_back(NewValueNode(primitiveTValuePtr)); + std::shared_ptr primitiveCPtr = std::make_shared(primitiveT.release()); + MS_ASSERT(primitiveCPtr != nullptr); + inputs.push_back(NewValueNode(primitiveCPtr)); AbstractBasePtrList elem; for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { const onnx::ValueInfoProto &output_node = importProto.output(out_size); @@ -539,7 +535,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto primReturn = std::make_unique(); MS_ASSERT(primReturn != nullptr); primReturn->value.type = schema::PrimitiveType_Return; - std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); + std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); MS_ASSERT(primitiveTReturnValuePtr != nullptr); inputs.push_back(NewValueNode(primitiveTReturnValuePtr)); inputs.push_back(maketuple_ptr); @@ -562,7 +558,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto primReturn = std::make_unique(); MS_ASSERT(primReturn != nullptr); 
primReturn->value.type = schema::PrimitiveType_Return; - std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); + std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); MS_ASSERT(primitiveTReturnValuePtr != nullptr); inputs.push_back(NewValueNode(primitiveTReturnValuePtr)); inputs.push_back(cnode_ptr); @@ -658,7 +654,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { } onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { - std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); + std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); #ifdef _WIN32 if (_fullpath(onnx_file.get(), model_path.c_str(), 1024) == nullptr) { MS_LOG(ERROR) << "open file failed."; @@ -681,7 +677,7 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string delete onnx_model; return nullptr; } - (void)close(fd); + (void) close(fd); MS_LOG(INFO) << "enter ReadProtoFromBinary success!" 
<< std::endl; return onnx_model; } diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h index d39cacc82cc..7aeaee55985 100644 --- a/mindspore/lite/tools/benchmark/benchmark.h +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -26,11 +26,10 @@ #include #include #include +#include "include/model.h" #include "tools/common/flag_parser.h" #include "src/common/file_utils.h" #include "src/common/utils.h" -#include "schema/model_generated.h" -#include "include/model.h" #include "include/lite_session.h" #include "include/inference.h" diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 06a863ca34b..3fdcf20afca 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -1,3 +1,4 @@ +add_definitions(-DPRIMITIVE_WRITEABLE) set(ANF_SRC ${ANF_SRC} #core / abstract diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 60e7f16c909..2501121890c 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -41,9 +41,9 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared(true, "conv_relu", schema::PrimitiveType_Activation, - schema::ActivationType_RELU)); + schema::ActivationType_RELU)); pm->AddPass(std::make_shared(true, "conv_relu6", schema::PrimitiveType_Activation, - schema::ActivationType_RELU6)); + schema::ActivationType_RELU6)); pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(old_graph); @@ -51,4 +51,3 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) { } } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc 
b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 8538c6cf9bf..60ee2b63b79 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "tools/converter/quantizer/post_training_quantizer.h" #include #include #include @@ -28,7 +29,6 @@ #include "schema/inner/model_generated.h" #include "src/ir/tensor.h" #include "tools/anf_exporter/anf_exporter.h" -#include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quantize_util.h" #include "utils/log_adapter.h" #include "securec/include/securec.h" @@ -208,9 +208,8 @@ struct DivergInfo { } } this->best_T = (static_cast(threshold) + 0.5f) * this->interval; - MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold - << " T: " << best_T - << " max: " << std::max(fabs(this->max), fabs(this->min)); + MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T + << " max: " << std::max(fabs(this->max), fabs(this->min)); return RET_OK; } @@ -466,10 +465,10 @@ STATUS Calibrator::ReadConfig() { MS_LOG(WARNING) << "unsupported parameter"; } } - MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " " - << "batch_count: " << config_param_.batch_count << " " - << "mothod_x: " << config_param_.method_x << " " - << "thread_num: " << config_param_.thread_num; + MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " " + << "batch_count: " << config_param_.batch_count << " " + << "mothod_x: " << config_param_.method_x << " " + << "thread_num: " << config_param_.thread_num; delete[] resolved_path; fs.close(); @@ -502,7 +501,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in } STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, - 
std::shared_ptr lite_primitive) { + std::shared_ptr lite_primitive) { if (!lite_primitive->GetInputQuantParams().empty()) { return RET_OK; } @@ -519,7 +518,7 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M } STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, - std::shared_ptr lite_primitive) { + std::shared_ptr lite_primitive) { if (!lite_primitive->GetOutputQuantParams().empty()) { return RET_OK; } @@ -535,7 +534,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct return RET_OK; } -STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, +STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, bool perchanel, bool depthwise) { // const vector dims = filter->dims; // perlayer @@ -574,7 +573,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr

primitiveT_value) { +STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitiveT_value) { if (primitiveT_value == nullptr || bias == nullptr) { MS_LOG(ERROR) << "null pointer!"; return RET_NULL_PTR; @@ -646,8 +645,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrtensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); + auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; delete[] quant_datas; @@ -685,7 +683,7 @@ STATUS PostTrainingQuantizer::QuantNode() { MS_LOG(INFO) << cnode_name << " can not do quant"; continue; } - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; continue; @@ -696,7 +694,7 @@ STATUS PostTrainingQuantizer::QuantNode() { } primitiveT_value->ClearInputOutputQuantParam(); auto op_name = cnode->fullname_with_scope(); - auto op_type = primitiveT_value->GetPrimitiveT()->value.type; + auto op_type = (schema::PrimitiveType)primitiveT_value->Type(); MS_LOG(INFO) << "OpName: " << op_name; if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && op_type != PrimitiveType_FullConnection) { @@ -724,10 +722,10 @@ STATUS PostTrainingQuantizer::QuantNode() { continue; } auto input_cnode = std::dynamic_pointer_cast(input_node); - auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); + auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); if (input_cnode_primitiveT_value == nullptr) { MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " - << " PrimitiveTValue is null"; + << " PrimitiveC is null"; continue; } if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) { diff --git 
a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 587689c5c20..1c47ee4dff7 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -25,7 +25,6 @@ #include #include "src/lite_session.h" #include "tools/converter/quantizer/quantizer.h" -#include "src/ir/primitive_t_value.h" #include "tools/converter/converter.h" #include "include/ms_tensor.h" @@ -93,13 +92,13 @@ class PostTrainingQuantizer : public Quantizer { // STATUS reformatConvWeight(GraphDefT *graph); - STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); - STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); + STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); - STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, bool perchannel, + STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, bool perchannel, bool depthwise); - STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitiveT_value); + STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitiveT_value); }; struct DivergInfo; diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index 9a229ba620e..83c94490323 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -18,7 +18,7 @@ #include "mindspore/lite/tools/converter/quantizer/quant_cast.h" #include #include -#include "mindspore/lite/src/ir/primitive_t_value.h" +#include "src/ops/primitive_c.h" namespace mindspore::lite::quant { @@ -28,7 +28,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int 
dst_type, const std::vector quant_dtype_cast.srcT = src_type; // kNumberTypeInt8; quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; primitive->value.Set(quant_dtype_cast); - auto primTValue = std::make_shared(primitive.release()); + auto primTValue = std::make_shared(primitive.release()); primTValue->SetQuantType(schema::QuantType_PostTraining); for (auto &quant_param : quant_params) { std::vector quant_params_in = {quant_param}; @@ -44,7 +44,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { bool first = true; for (auto &cnode : cnodes) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); auto curnode_quant_type = schema::QuantType_QUANT_NONE; if (primitiveT_value == nullptr) { MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); @@ -54,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (first) { if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { auto value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); std::vector op_inputs = {value_node, cnode->input(1)}; auto quant_cast_cnode = graph->NewCNode(op_inputs); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); @@ -72,10 +72,10 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { continue; } auto input_cnode = std::dynamic_pointer_cast(input_node); - auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); + auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); if (input_cnode_primitiveT_value == nullptr) { MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " - << " PrimitiveTValue is null"; + << " PrimitiveC is null"; continue; } auto input_cnode_quant_type = 
input_cnode_primitiveT_value->GetQuantType(); @@ -87,7 +87,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && - input_cnode_quant_type == schema::QuantType_PostTraining) { + input_cnode_quant_type == schema::QuantType_PostTraining) { value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, input_cnode_primitiveT_value->GetInputQuantParams().front()); } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 50bf9fa11f7..884a74daecc 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -18,7 +18,7 @@ #include #include #include -#include "src/ir/primitive_t_value.h" +#include "src/ops/primitive_c.h" #include "mindspore/lite/tools/converter/quantizer/quantize_util.h" #include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h" #include "src/common/utils.h" @@ -32,7 +32,7 @@ namespace mindspore { namespace lite { namespace quant { const std::array QuantStrategy::mConvTypes = { - {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}}; + {"Conv2D", "DeConv2D", "DepthwiseConv2D", "DeDepthwiseConv2D"}}; const std::array QuantStrategy::mMulTypes = {{"Mul", "MatMul", "BatchMatMul", "FullConnection"}}; QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold) @@ -87,7 +87,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { } auto cnode = std::dynamic_pointer_cast(node); - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); return false; @@ -247,7 +247,7 @@ STATUS 
CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl } int quantMin = narrowRange ? 1 : 0; - int quantMax = (1 << (unsigned int)numBits) - 1; + int quantMax = (1 << (unsigned int) numBits) - 1; auto quantMinFloat = static_cast(quantMin); auto quantMaxFloat = static_cast(quantMax); double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); @@ -279,7 +279,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl return RET_OK; } -STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, +STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { auto dims = weight->tensor_shape(); if (per_channel) { @@ -360,7 +360,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr pr } weight->set_tensor_size(elem_count * sizeof(int8_t)); - } else { + } else { // channel at first auto channels = dims[0]; if (channels == 0) { @@ -402,7 +402,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr pr } } auto ret = - memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); + memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index b38962dfb9a..2823338f81f 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -21,6 +21,7 @@ #include #include #include +#include "src/ops/primitive_c.h" #include "include/errorcode.h" #include "ir/func_graph.h" #include "ir/anf.h" @@ -29,7 +30,6 @@ #include "ir/primitive.h" #include "abstract/dshape.h" #include 
"mindspore/lite/tools/converter/quantizer/quantizer.h" -#include "mindspore/lite/src/ir/primitive_t_value.h" namespace mindspore { namespace lite { @@ -65,7 +65,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange = false, int numBits = UINT8_QUANTIZATION); -template +template T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam->inited); @@ -73,7 +73,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { const auto zeroPoint = quantParam->zeroPoint; const auto numBit = quantParam->numBits; const auto narrowRange = quantParam->narrowRange; - const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; + const double maxLimit = static_cast((1 << (unsigned int) numBit) - 1 - zeroPoint) * scale; double minLimit; if (narrowRange) { minLimit = static_cast(1 - zeroPoint) * scale; @@ -97,7 +97,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { }(); } -template +template T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max, int quant_min) { MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam->inited); @@ -118,7 +118,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan }(); } -STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, +STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false, bool depth_wise = false); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 39d1dc37376..c4c1aace5b8 100644 --- 
a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -17,7 +17,7 @@ #include #include #include -#include "src/ir/primitive_t_value.h" +#include "src/ops/primitive_c.h" #include "frontend/operator/ops.h" #include "backend/optimizer/common/helper.h" @@ -138,7 +138,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { auto b_prim = b_value->cast(); MS_EXCEPTION_IF_NULL(b_prim); - return a_prim->name() == b_prim->name(); + return a_prim->cast()->Type() == b_prim->cast()->Type(); } else if (a_node->isa() && b_node->isa()) { auto a_value_node_ptr = a_node->cast(); if (a_value_node_ptr == nullptr) { @@ -158,18 +158,18 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { MS_LOG(EXCEPTION) << "value ptr is nullptr"; } - if (utils::isa(a_value_ptr) && utils::isa(b_value_ptr)) { - auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get()); - auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get()); + if (utils::isa(a_value_ptr) && utils::isa(b_value_ptr)) { + auto a_obj = (lite::PrimitiveC *) (a_value_ptr.get()); + auto b_obj = (lite::PrimitiveC *) (b_value_ptr.get()); return (*a_obj) == (*b_obj); } else { return (*a_value_ptr) == (*b_value_ptr); } } } - if (a.m_ptr->isa() && b.m_ptr->isa()) { - auto a_value_node_ptr = a.m_ptr->cast(); - auto b_value_node_ptr = b.m_ptr->cast(); + if (a.m_ptr->isa() && b.m_ptr->isa()) { + auto a_value_node_ptr = a.m_ptr->cast(); + auto b_value_node_ptr = b.m_ptr->cast(); return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; } @@ -313,8 +313,8 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) { MS_EXCEPTION_IF_NULL(value_node); auto value = value_node->value(); MS_ASSERT(value != nullptr); - if (utils::isa(value)) { - auto primitive = value->cast(); + if (utils::isa(value)) { + auto primitive = value->cast(); MS_ASSERT(primitive != nullptr); return primitive->GetPrimitiveT()->value.type; } else if (utils::isa(value)) { 
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index e970c0babd8..2299779b394 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ #include -#include "src/ir/primitive_t_value.h" +#include "src/ops//primitive_c.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "src/common/utils.h" @@ -26,7 +26,7 @@ #include "schema/inner/model_generated.h" #include "src/param_value_lite.h" -using PrimitiveTValuePtr = std::shared_ptr; +using PrimitiveCPtr = std::shared_ptr; namespace mindspore { namespace opt { bool IsRealCNodeKernel(const AnfNodePtr &node); diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index b2f788513af..1906d4b4c4e 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -25,12 +25,11 @@ #include "src/scheduler.h" #include "include/context.h" #include "src/lite_session.h" -#include "src/ir/primitive_t_value.h" #include "src/populate_parameter.h" #include "src/ops/primitive_c.h" using mindspore::lite::KernelRegistry; -using mindspore::lite::PrimitiveTValue; +using mindspore::lite::PrimitiveC; using mindspore::lite::tensor::Tensor; namespace mindspore::opt { namespace { @@ -45,7 +44,7 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); auto tensor_shape = tensorT->dims; auto lite_tensor = - new(std::nothrow)Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType); + new(std::nothrow)Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType); if (lite_tensor == nullptr) { MS_LOG(ERROR) << "lite tensor is nullptr"; return 
input_tensors; @@ -74,7 +73,7 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { return input_tensors; } schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; return nullptr; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc index 14d1b3972f7..496b73406c5 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -16,9 +16,9 @@ #include "tools/optimizer/fusion/conv_activation_fusion.h" #include +#include "mindspore/lite/src/ops/activation.h" +#include "src/ops/primitive_c.h" #include "schema/inner/model_generated.h" -#include "src/ir/primitive_t_value.h" -#include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" namespace mindspore::opt { @@ -29,7 +29,7 @@ const BaseRef ConvActivationFusion::DefinePattern() const { auto conv_var = std::make_shared(IsConvNode); auto prim = new schema::PrimitiveT(); prim->value.type = primitive_type; - auto prim_value = std::make_shared(prim); + auto prim_value = std::make_shared(prim); return VectorRef({prim_value, conv_var}); } @@ -44,7 +44,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c CheckIfCNodeIsNull(act_node); CheckInputSize(act_node, kActivationInputsLength); - auto act_primitive = GetValueNode>(act_node->input(0)); + auto act_primitive = GetValueNode>(act_node->input(0)); if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) { return node; } @@ -56,7 +56,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c } auto conv_node = pre_node->cast(); auto node_type = GetCNodeType(conv_node); - auto primitiveT_value = 
GetValueNode>(conv_node->input(0)); + auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value); if (node_type == schema::PrimitiveType_Conv2D) { primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index afacf8dcb5b..8f24c0fa785 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -15,9 +15,9 @@ */ #include "tools/optimizer/fusion/conv_biasadd_fusion.h" #include +#include "src/ops/primitive_c.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" -#include "src/ir/primitive_t_value.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -53,9 +53,9 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { MS_ASSERT(value_node != nullptr); auto value = value_node->value(); MS_ASSERT(value != nullptr); - auto primitive = value->cast(); + auto primitive = value->cast(); MS_ASSERT(primitive != nullptr); - auto type = primitive->GetPrimitiveT()->value.type; + auto type = (schema::PrimitiveType)primitive->Type(); if (type == schema::PrimitiveType_Conv2D) { return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; } else if (type == schema::PrimitiveType_DepthwiseConv2D) { @@ -149,7 +149,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons auto conv_node = conv_node_anf->cast(); CheckIfCNodeIsNull(conv_node); GenConvNewBias(func_graph, conv_node, add_node); - auto primitiveT_value = GetValueNode>(conv_node->input(0)); + auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value != nullptr); auto type = primitiveT_value->GetPrimitiveT()->value.type; if (type == schema::PrimitiveType_Conv2D) { diff --git 
a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc index 6c511c0dab4..b02eccd3fe9 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -16,9 +16,9 @@ #include "tools/optimizer/fusion/conv_bn_fusion.h" #include +#include "src/ops/primitive_c.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" -#include "src/ir/primitive_t_value.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -113,7 +113,7 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern AnfNodePtr bn_scale_node = nullptr; AnfNodePtr bn_bias_node = nullptr; float eps = 0; - auto primitiveT_value = GetValueNode>(bn_node->input(0)); + auto primitiveT_value = GetValueNode>(bn_node->input(0)); if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) { bn_mean_node = bn_node->input(kCaffeBNMeanIndex); bn_variance_node = bn_node->input(kCaffeBNVarIndex); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc index 10aadeda4c1..8a8ff036f78 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc @@ -16,12 +16,10 @@ #include "tools/optimizer/fusion/conv_scale_fusion.h" #include +#include "src/ops/primitive_c.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" -#include "src/ir/primitive_t_value.h" -#include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "include/errorcode.h" #include "securec/include/securec.h" namespace mindspore::opt { diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index 969ff385e43..05fc103689c 100644 --- 
a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -16,12 +16,10 @@ #include "tools/optimizer/fusion/conv_transform_fusion.h" #include +#include "src/ops/primitive_c.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" -#include "src/ir/primitive_t_value.h" -#include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "include/errorcode.h" #include "securec/include/securec.h" namespace mindspore::opt { @@ -38,14 +36,14 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { MS_ASSERT(value_node != nullptr); auto value = value_node->value(); MS_ASSERT(value != nullptr); - auto primitive = value->cast(); + auto primitive = value->cast(); MS_ASSERT(primitive != nullptr); auto type = primitive->GetPrimitiveT()->value.type; if (type == schema::PrimitiveType_Conv2D) { return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; } else if (type == schema::PrimitiveType_DepthwiseConv2D) { return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * - primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; } else { MS_LOG(ERROR) << "Unsupported opType, " << type; return 0; @@ -91,7 +89,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); delete[] trans_bias; delete[] trans_scale; - auto primitiveT_value = GetValueNode>(conv_node->input(0)); + auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value != nullptr); auto type = primitiveT_value->GetPrimitiveT()->value.type; if (type == schema::PrimitiveType_Conv2D) { @@ -180,7 +178,7 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int 
kernel_size, const float *trans_scale) const { MS_ASSERT(weight_data != nullptr); - auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size]; + auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size]; MS_ASSERT(new_weight_data != nullptr); auto data_size = kernel_num * kernel_size * sizeof(float); if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { diff --git a/mindspore/lite/tools/time_profile/time_profile.h b/mindspore/lite/tools/time_profile/time_profile.h index 6668c5edd60..9dc6d57d4fb 100644 --- a/mindspore/lite/tools/time_profile/time_profile.h +++ b/mindspore/lite/tools/time_profile/time_profile.h @@ -23,14 +23,11 @@ #include #include #include - +#include "include/lite_session.h" #include "tools/common/flag_parser.h" #include "src/common/file_utils.h" #include "src/common/utils.h" -#include "schema/model_generated.h" #include "include/model.h" -#include "include/lite_session.h" - namespace mindspore { namespace lite { @@ -70,7 +67,7 @@ class MS_API TimeProfile { int ReadInputFile(); int InitCallbackParameter(); int InitSession(); - int PrintResult(const std::vector& title, const std::map>& result); + int PrintResult(const std::vector &title, const std::map> &result); private: TimeProfileFlags *_flags;