From 28e35087183988952a76793469f4541314ff74b3 Mon Sep 17 00:00:00 2001 From: hangq Date: Sat, 22 Aug 2020 10:14:25 +0800 Subject: [PATCH] add UnPack method in ops & remove anf_importer populater --- mindspore/lite/CMakeLists.txt | 2 +- mindspore/lite/java/build_aar.sh | 12 +- mindspore/lite/java/native/CMakeLists.txt | 2 +- mindspore/lite/src/common/graph_util.cc | 1 + mindspore/lite/src/executor.cc | 1 + mindspore/lite/src/model.cc | 182 +------- mindspore/lite/src/ops/abs.h | 4 +- mindspore/lite/src/ops/activation.cc | 16 + mindspore/lite/src/ops/activation.h | 5 +- mindspore/lite/src/ops/activation_grad.h | 5 +- mindspore/lite/src/ops/add.cc | 24 ++ mindspore/lite/src/ops/add.h | 6 +- mindspore/lite/src/ops/addn.h | 4 +- mindspore/lite/src/ops/argmax.h | 5 +- mindspore/lite/src/ops/argmin.h | 5 +- mindspore/lite/src/ops/arithmetic.h | 5 +- mindspore/lite/src/ops/arithmetic_self.h | 5 +- mindspore/lite/src/ops/batch_norm.cc | 11 +- mindspore/lite/src/ops/batch_norm.h | 6 +- mindspore/lite/src/ops/batch_to_space.h | 5 +- mindspore/lite/src/ops/bias_add.cc | 11 + mindspore/lite/src/ops/bias_add.h | 6 +- mindspore/lite/src/ops/bias_grad.h | 5 +- mindspore/lite/src/ops/bn_grad_input.h | 5 +- mindspore/lite/src/ops/broadcast_to.h | 5 +- mindspore/lite/src/ops/caffe_p_relu.h | 5 +- mindspore/lite/src/ops/cast.h | 5 +- mindspore/lite/src/ops/ceil.h | 4 +- mindspore/lite/src/ops/clip.h | 5 +- mindspore/lite/src/ops/concat.cc | 12 + mindspore/lite/src/ops/concat.h | 6 +- mindspore/lite/src/ops/constant_of_shape.h | 4 +- mindspore/lite/src/ops/conv2d.cc | 264 ++++++++++++ mindspore/lite/src/ops/conv2d.h | 24 +- mindspore/lite/src/ops/conv2d_grad_filter.h | 5 +- mindspore/lite/src/ops/conv2d_grad_input.h | 5 +- mindspore/lite/src/ops/cos.h | 4 +- mindspore/lite/src/ops/crop.h | 5 +- mindspore/lite/src/ops/deconv2d.h | 5 +- mindspore/lite/src/ops/dedepthwise_conv2d.h | 5 +- mindspore/lite/src/ops/depth_to_space.h | 5 +- mindspore/lite/src/ops/depthwise_conv2d.cc | 168 +++++++- mindspore/lite/src/ops/depthwise_conv2d.h | 21 +- .../ops/dequant.cc} | 28 +- mindspore/lite/src/ops/dequant.h | 38 ++ .../lite/src/ops/detection_post_process.h | 5 +- mindspore/lite/src/ops/div.h | 5 +- mindspore/lite/src/ops/dropout.h | 5 +- mindspore/lite/src/ops/eltwise.h | 5 +- mindspore/lite/src/ops/elu.h | 5 +- mindspore/lite/src/ops/embedding_lookup.h | 5 +- .../lite/src/ops/embedding_lookup_sparse.h | 5 +- mindspore/lite/src/ops/equal.h | 4 +- mindspore/lite/src/ops/exp.h | 4 +- mindspore/lite/src/ops/expand_dims.h | 5 +- .../src/ops/fake_quant_with_min_max_vars.h | 5 +- mindspore/lite/src/ops/fill.h | 5 +- mindspore/lite/src/ops/flatten.cc | 10 + mindspore/lite/src/ops/flatten.h | 7 +- mindspore/lite/src/ops/floor.h | 4 +- mindspore/lite/src/ops/floor_div.h | 4 +- mindspore/lite/src/ops/floor_mod.h | 4 +- mindspore/lite/src/ops/full_connection.h | 5 +- mindspore/lite/src/ops/fused_batchnorm.h | 5 +- mindspore/lite/src/ops/gather.h | 5 +- mindspore/lite/src/ops/gather_nd.h | 5 +- mindspore/lite/src/ops/greater.h | 4 +- mindspore/lite/src/ops/greater_equal.h | 4 +- mindspore/lite/src/ops/l2_norm.h | 5 +- mindspore/lite/src/ops/leaky_relu.h | 5 +- mindspore/lite/src/ops/less.h | 4 +- mindspore/lite/src/ops/less_equal.h | 4 +- .../src/ops/local_response_normalization.h | 5 +- mindspore/lite/src/ops/log.h | 4 +- mindspore/lite/src/ops/logical_and.h | 4 +- mindspore/lite/src/ops/logical_not.h | 4 +- mindspore/lite/src/ops/logical_or.h | 4 +- mindspore/lite/src/ops/lrn.h | 5 +- mindspore/lite/src/ops/lstm.h | 5 +- .../ops/make_tuple.cc} | 31 +- mindspore/lite/src/ops/make_tuple.h | 37 ++ mindspore/lite/src/ops/matmul.cc | 100 +++++ mindspore/lite/src/ops/matmul.h | 19 +- mindspore/lite/src/ops/matrix_diag.h | 5 +- mindspore/lite/src/ops/maximum.h | 4 +- mindspore/lite/src/ops/mean.h | 5 +- mindspore/lite/src/ops/minimum.h | 4 +- mindspore/lite/src/ops/mul.cc | 9 + mindspore/lite/src/ops/mul.h | 7 +- mindspore/lite/src/ops/nchw2nhwc.h | 5 +- mindspore/lite/src/ops/nhwc2nchw.h | 5 +- mindspore/lite/src/ops/not_equal.h | 4 +- mindspore/lite/src/ops/one_hot.h | 5 +- mindspore/lite/src/ops/pad.h | 5 +- mindspore/lite/src/ops/permute.h | 5 +- mindspore/lite/src/ops/pooling.cc | 44 ++ mindspore/lite/src/ops/pooling.h | 7 +- mindspore/lite/src/ops/pooling_grad.h | 5 +- mindspore/lite/src/ops/power.h | 5 +- mindspore/lite/src/ops/power_grad.h | 5 +- mindspore/lite/src/ops/prelu.h | 5 +- mindspore/lite/src/ops/primitive_c.cc | 408 +++++++++++++++--- mindspore/lite/src/ops/primitive_c.h | 37 +- mindspore/lite/src/ops/prior_box.h | 5 +- .../ops/quant.cc} | 28 +- mindspore/lite/src/ops/quant.h | 37 ++ mindspore/lite/src/ops/quant_dtype_cast.h | 6 +- mindspore/lite/src/ops/range.h | 5 +- mindspore/lite/src/ops/rank.h | 5 +- mindspore/lite/src/ops/reduce.cc | 33 ++ mindspore/lite/src/ops/reduce.h | 6 +- mindspore/lite/src/ops/reshape.cc | 27 ++ mindspore/lite/src/ops/reshape.h | 6 +- mindspore/lite/src/ops/resize.h | 5 +- mindspore/lite/src/ops/reverse.h | 5 +- mindspore/lite/src/ops/reverse_sequence.h | 5 +- mindspore/lite/src/ops/roi_pooling.h | 5 +- mindspore/lite/src/ops/round.h | 4 +- mindspore/lite/src/ops/rsqrt.h | 4 +- mindspore/lite/src/ops/scale.h | 5 +- mindspore/lite/src/ops/scatter_nd.h | 5 +- mindspore/lite/src/ops/shape.h | 5 +- mindspore/lite/src/ops/sin.h | 4 +- mindspore/lite/src/ops/slice.cc | 26 +- mindspore/lite/src/ops/slice.h | 9 +- mindspore/lite/src/ops/softmax.h | 5 +- .../lite/src/ops/softmax_cross_entropy.h | 5 +- mindspore/lite/src/ops/space_to_batch.h | 5 +- mindspore/lite/src/ops/space_to_batch_nd.h | 5 +- mindspore/lite/src/ops/space_to_depth.h | 5 +- mindspore/lite/src/ops/sparse_to_dense.h | 5 +- mindspore/lite/src/ops/split.h | 5 +- mindspore/lite/src/ops/sqrt.h | 4 +- mindspore/lite/src/ops/square.h | 4 +- mindspore/lite/src/ops/squared_difference.h | 4 +- mindspore/lite/src/ops/squeeze.h | 5 +- mindspore/lite/src/ops/stack.cc | 4 +- mindspore/lite/src/ops/stack.h | 5 +- mindspore/lite/src/ops/strided_slice.h | 5 +- mindspore/lite/src/ops/sub.h | 5 +- mindspore/lite/src/ops/tile.h | 5 +- mindspore/lite/src/ops/topk.h | 5 +- mindspore/lite/src/ops/transpose.cc | 27 ++ mindspore/lite/src/ops/transpose.h | 6 +- .../ops/tuple_get_item.cc} | 29 +- mindspore/lite/src/ops/tuple_get_item.h | 38 ++ mindspore/lite/src/ops/unique.h | 5 +- mindspore/lite/src/ops/unsqueeze.h | 5 +- mindspore/lite/src/ops/unstack.h | 5 +- mindspore/lite/src/ops/upsample.h | 5 +- mindspore/lite/src/ops/where.h | 5 +- mindspore/lite/src/ops/zeros_like.h | 5 +- mindspore/lite/src/populate_parameter.cc | 16 +- mindspore/lite/src/populate_parameter.h | 5 +- .../lite/tools/anf_exporter/anf_exporter.cc | 64 +-- .../anf_populater/anf_activation_populater.cc | 45 -- .../anf_populater/anf_activation_populater.h | 33 -- .../anf_populater/anf_batchnorm_populater.cc | 36 -- .../anf_populater/anf_batchnorm_populater.h | 31 -- .../anf_populater/anf_biasadd_populater.cc | 37 -- .../anf_populater/anf_biasadd_populater.h | 31 -- .../anf_populater/anf_concat_populater.cc | 40 -- .../anf_populater/anf_conv_populater.cc | 240 ----------- .../anf_populater/anf_conv_populater.h | 99 ----- .../anf_depthwiseconv2d_populater.cc | 195 --------- .../anf_depthwiseconv2d_populater.h | 40 -- .../anf_populater/anf_dequant_populater.cc | 36 -- .../anf_populater/anf_flatten_populater.cc | 36 -- .../anf_populater/anf_flatten_populater.h | 31 -- .../anf_populater/anf_make_tuple_populater.cc | 35 -- .../anf_populater/anf_matmul_populater.cc | 126 ------ .../anf_populater/anf_matmul_populater.h | 39 -- .../anf_populater/anf_mul_populater.cc | 35 -- .../anf_populater/anf_mul_populater.h | 31 -- .../anf_populater/anf_node_populater.cc | 19 - .../anf_populater/anf_node_populater.h | 40 -- .../anf_node_populater_registry.cc | 43 -- .../anf_node_populater_registry.h | 43 -- .../anf_populater/anf_pool_populater.cc | 68 --- .../anf_populater/anf_pool_populater.h | 30 -- .../anf_populater/anf_quant_populater.cc | 36 -- .../anf_populater/anf_reducemean_populater.cc | 58 --- .../anf_populater/anf_reducemean_populater.h | 30 -- .../anf_populater/anf_reshape_populater.cc | 54 --- .../anf_populater/anf_reshape_populater.h | 31 -- .../anf_populater/anf_tensoradd_populater.cc | 35 -- .../anf_populater/anf_tensoradd_populater.h | 30 -- .../anf_populater/anf_transpose_populater.cc | 54 --- .../anf_populater/anf_transpose_populater.h | 30 -- .../anf_tuple_getitem_populater.cc | 35 -- .../anf_tuple_getitem_populater.h | 30 -- .../anf_importer/import_from_meta_graphT.cc | 6 +- .../anf_importer/import_from_protobuf.cc | 61 ++- .../legacy_optimizer/fusion/fusion_pass.cc | 29 +- .../parser/tflite/tflite_model_parser.cc | 2 + .../converter/quantizer/calc_quant_param.cc | 2 +- .../tools/converter/quantizer/quantize_util.h | 16 +- .../tools/converter/quantizer/quantizer.h | 2 +- .../fusion/constant_folding_fusion.cc | 22 +- .../fusion/conv_activation_fusion.cc | 22 +- .../optimizer/fusion/conv_biasadd_fusion.cc | 38 +- .../optimizer/fusion/conv_transform_fusion.cc | 38 +- 202 files changed, 1967 insertions(+), 2555 deletions(-) rename mindspore/lite/{tools/anf_importer/anf_populater/anf_make_tuple_populater.h => src/ops/dequant.cc} (53%) create mode 100644 mindspore/lite/src/ops/dequant.h rename mindspore/lite/{tools/anf_importer/anf_populater/anf_quant_populater.h => src/ops/make_tuple.cc} (55%) create mode 100644 mindspore/lite/src/ops/make_tuple.h rename mindspore/lite/{tools/anf_importer/anf_populater/anf_concat_populater.h => src/ops/quant.cc} (54%) create mode 100644 mindspore/lite/src/ops/quant.h rename mindspore/lite/{tools/anf_importer/anf_populater/anf_dequant_populater.h => src/ops/tuple_get_item.cc} (54%) create mode 100644 mindspore/lite/src/ops/tuple_get_item.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc delete mode 100644 mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index eff4c2613a9..bf1176b4ad0 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -6,7 +6,7 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_ endif () set(MS_VERSION_MAJOY 0) -set(MS_VERSION_MINOR 6) +set(MS_VERSION_MINOR 7) set(MS_VERSION_REVISION 0) set(DIR_PREFIX mindspore-lite) diff --git a/mindspore/lite/java/build_aar.sh b/mindspore/lite/java/build_aar.sh index fd683581e5b..fd6faa86b40 100644 --- a/mindspore/lite/java/build_aar.sh +++ b/mindspore/lite/java/build_aar.sh @@ -5,13 +5,13 @@ BASE_PATH=$(cd "$(dirname $0)"; pwd) TOP_PATH="${BASE_PATH}/../../.." # build mindspore-lite arm64 cd ${TOP_PATH} -#bash build.sh -I arm64 -#COMPILE_RET=$? +bash build.sh -I arm64 +COMPILE_RET=$? -#if [[ "${COMPILE_RET}" -ne 0 ]]; then -# echo "---------------- mindspore lite: build failed ----------------" -# exit -#fi +if [[ "${COMPILE_RET}" -ne 0 ]]; then + echo "---------------- mindspore lite: build failed ----------------" + exit +fi # copy arm64 so cd ${TOP_PATH}/output/ diff --git a/mindspore/lite/java/native/CMakeLists.txt b/mindspore/lite/java/native/CMakeLists.txt index f9336c3a115..e9cb68ef24b 100644 --- a/mindspore/lite/java/native/CMakeLists.txt +++ b/mindspore/lite/java/native/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.14) project (Lite-java) set(MS_VERSION_MAJOY 0) -set(MS_VERSION_MINOR 6) +set(MS_VERSION_MINOR 7) set(MS_VERSION_REVISION 0) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc index e6f2ae5dfd3..1ff45db63d2 100755 --- a/mindspore/lite/src/common/graph_util.cc +++ b/mindspore/lite/src/common/graph_util.cc @@ -32,6 +32,7 @@ std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph) { for (size_t j = 0; j < meta_graph->nodes()->size(); j++) { auto *cNode = meta_graph->nodes()->GetAs(j); MS_ASSERT(nullptr != cNode); + MS_ASSERT(nullptr != cNode->inputIndex()); for (size_t k = 0; k < cNode->inputIndex()->size(); k++) { if (cNode->inputIndex()->GetAs(k) == input_index) { if (!IsContain(ret, j)) { diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 449de53800c..355845f91ca 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -53,6 +53,7 @@ int Executor::Run(std::vector &in_tensors, std::vectorname(); return ret; } + if (after != nullptr) { if (!after(PackToMSTensors(kernel->in_tensors()), PackToMSTensors(kernel->out_tensors()), {kernel->name(), kernel->type_str()})) { diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 0a2024a90fd..7704c2c8564 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -188,186 +188,6 @@ void ModelImpl::FreeMetaGraph() { const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; } -PrimitiveC *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) { - MS_EXCEPTION_IF_NULL(src_prim); - auto op_type = src_prim->value_type(); - switch (op_type) { - case schema::PrimitiveType_SoftMax: - return new SoftMax(const_cast(src_prim)); - case schema::PrimitiveType_Activation: - return new Activation(const_cast(src_prim)); - case schema::PrimitiveType_Conv2D: - return new Conv2D(const_cast(src_prim)); - case schema::PrimitiveType_DeConv2D: - return new DeConv2D(const_cast(src_prim)); - case schema::PrimitiveType_Reduce: - return new Reduce(const_cast(src_prim)); - case schema::PrimitiveType_Pooling: - return new Pooling(const_cast(src_prim)); - case schema::PrimitiveType_DepthwiseConv2D: - return new DepthwiseConv2D(const_cast(src_prim)); - case schema::PrimitiveType_FusedBatchNorm: - return new FusedBatchNorm(const_cast(src_prim)); - case schema::PrimitiveType_BatchNorm: - return new BatchNorm(const_cast(src_prim)); - case schema::PrimitiveType_FullConnection: - return new FullConnection(const_cast(src_prim)); - case schema::PrimitiveType_Power: - return new Power(const_cast(src_prim)); - case schema::PrimitiveType_Range: - return new Range(const_cast(src_prim)); - case schema::PrimitiveType_Mul: - return new Mul(const_cast(src_prim)); - case schema::PrimitiveType_Add: - return new Add(const_cast(src_prim)); - case schema::PrimitiveType_Sub: - return new Sub(const_cast(src_prim)); - case schema::PrimitiveType_Div: - return new Div(const_cast(src_prim)); - case schema::PrimitiveType_BiasAdd: - return new BiasAdd(const_cast(src_prim)); - case schema::PrimitiveType_ExpandDims: - return new ExpandDims(const_cast(src_prim)); - case schema::PrimitiveType_ArgMax: - return new ArgMax(const_cast(src_prim)); - case schema::PrimitiveType_ArgMin: - return new ArgMin(const_cast(src_prim)); - case schema::PrimitiveType_Cast: - return new Cast(const_cast(src_prim)); - case schema::PrimitiveType_Reshape: - return new Reshape(const_cast(src_prim)); - case schema::PrimitiveType_Scale: - return new Scale(const_cast(src_prim)); - case schema::PrimitiveType_Eltwise: - return new Eltwise(const_cast(src_prim)); - case schema::PrimitiveType_Concat: - return new Concat(const_cast(src_prim)); - case schema::PrimitiveType_Fill: - return new Fill(const_cast(src_prim)); - case schema::PrimitiveType_Transpose: - return new Transpose(const_cast(src_prim)); - case schema::PrimitiveType_Slice: - return new SliceOp(const_cast(src_prim)); - case schema::PrimitiveType_Squeeze: - return new Squeeze(const_cast(src_prim)); - case schema::PrimitiveType_Nchw2Nhwc: - return new Nchw2Nhwc(const_cast(src_prim)); - case schema::PrimitiveType_Nhwc2Nchw: - return new Nhwc2Nchw(const_cast(src_prim)); - case schema::PrimitiveType_Flatten: - return new Flatten(const_cast(src_prim)); - case schema::PrimitiveType_Mean: - return new Mean(const_cast(src_prim)); - case schema::PrimitiveType_Stack: - return new Stack(const_cast(src_prim)); - case schema::PrimitiveType_Crop: - return new Crop(const_cast(src_prim)); - case schema::PrimitiveType_SquaredDifference: - return new SquaredDifference(const_cast(src_prim)); - case schema::PrimitiveType_AddN: - return new AddN(const_cast(src_prim)); - case schema::PrimitiveType_Abs: - return new Abs(const_cast(src_prim)); - case schema::PrimitiveType_Sin: - return new Sin(const_cast(src_prim)); - case schema::PrimitiveType_Cos: - return new Cos(const_cast(src_prim)); - case schema::PrimitiveType_Log: - return new Log(const_cast(src_prim)); - case schema::PrimitiveType_Sqrt: - return new Sqrt(const_cast(src_prim)); - case schema::PrimitiveType_Rsqrt: - return new Rsqrt(const_cast(src_prim)); - case schema::PrimitiveType_Square: - return new Square(const_cast(src_prim)); - case schema::PrimitiveType_Exp: - return new Exp(const_cast(src_prim)); - case schema::PrimitiveType_Gather: - return new Gather(const_cast(src_prim)); - case schema::PrimitiveType_GatherNd: - return new GatherNd(const_cast(src_prim)); - case schema::PrimitiveType_LocalResponseNormalization: - return new LocalResponseNormalization(const_cast(src_prim)); - case schema::PrimitiveType_Maximum: - return new Maximum(const_cast(src_prim)); - case schema::PrimitiveType_Minimum: - return new Minimum(const_cast(src_prim)); - case schema::PrimitiveType_Pad: - return new Pad(const_cast(src_prim)); - case schema::PrimitiveType_StridedSlice: - return new StridedSlice(const_cast(src_prim)); - case schema::PrimitiveType_Prelu: - return new Prelu(const_cast(src_prim)); - case schema::PrimitiveType_CaffePReLU: - return new CaffePReLU(const_cast(src_prim)); - case schema::PrimitiveType_Round: - return new Round(const_cast(src_prim)); - case schema::PrimitiveType_Reverse: - return new Reverse(const_cast(src_prim)); - case schema::PrimitiveType_ReverseSequence: - return new ReverseSequence(const_cast(src_prim)); - case schema::PrimitiveType_LogicalAnd: - return new LogicalAnd(const_cast(src_prim)); - case schema::PrimitiveType_LogicalOr: - return new LogicalOr(const_cast(src_prim)); - case schema::PrimitiveType_LogicalNot: - return new LogicalNot(const_cast(src_prim)); - case schema::PrimitiveType_FloorDiv: - return new FloorDiv(const_cast(src_prim)); - case schema::PrimitiveType_FloorMod: - return new FloorMod(const_cast(src_prim)); - case schema::PrimitiveType_Equal: - return new Equal(const_cast(src_prim)); - case schema::PrimitiveType_NotEqual: - return new NotEqual(const_cast(src_prim)); - case schema::PrimitiveType_Less: - return new Less(const_cast(src_prim)); - case schema::PrimitiveType_LessEqual: - return new LessEqual(const_cast(src_prim)); - case schema::PrimitiveType_Greater: - return new Greater(const_cast(src_prim)); - case schema::PrimitiveType_GreaterEqual: - return new GreaterEqual(const_cast(src_prim)); - case schema::PrimitiveType_Floor: - return new Floor(const_cast(src_prim)); - case schema::PrimitiveType_Ceil: - return new Ceil(const_cast(src_prim)); - case schema::PrimitiveType_Split: - return new Split(const_cast(src_prim)); - case schema::PrimitiveType_OneHot: - return new OneHot(const_cast(src_prim)); - case schema::PrimitiveType_SpaceToDepth: - return new SpaceToDepth(const_cast(src_prim)); - case schema::PrimitiveType_Tile: - return new Tile(const_cast(src_prim)); - case schema::PrimitiveType_Resize: - return new Resize(const_cast(src_prim)); - case schema::PrimitiveType_Unstack: - return new Unstack(const_cast(src_prim)); - case schema::PrimitiveType_Unique: - return new Unique(const_cast(src_prim)); - case schema::PrimitiveType_TopK: - return new TopK(const_cast(src_prim)); - case schema::PrimitiveType_MatMul: - return new MatMul(const_cast(src_prim)); - case schema::PrimitiveType_QuantDTypeCast: - return new QuantDTypeCast(const_cast(src_prim)); - case schema::PrimitiveType_EmbeddingLookup: - return new EmbeddingLookup(const_cast(src_prim)); - case schema::PrimitiveType_Elu: - return new Elu(const_cast(src_prim)); - case schema::PrimitiveType_DeDepthwiseConv2D: - return new DeDepthwiseConv2D(const_cast(src_prim)); - case schema::PrimitiveType_Shape: - return new Shape(const_cast(src_prim)); - case schema::PrimitiveType_Unsqueeze: - return new Unsqueeze(const_cast(src_prim)); - default: - break; - } - return nullptr; -} - int ModelImpl::BuildOps() { if (this->meta_graph_ == nullptr) { MS_LOG(ERROR) << "mete_graph is nullptr"; @@ -379,7 +199,7 @@ int ModelImpl::BuildOps() { auto name = cNode->name()->str(); auto srcPrim = cNode->primitive(); - this->ops_[name] = CopyPrimitive(srcPrim); + this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast(srcPrim)); } return 0; } diff --git a/mindspore/lite/src/ops/abs.h b/mindspore/lite/src/ops/abs.h index faac704bd23..92146282049 100644 --- a/mindspore/lite/src/ops/abs.h +++ b/mindspore/lite/src/ops/abs.h @@ -33,9 +33,11 @@ namespace lite { class Abs : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Abs() = default; explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index 608241cfa52..861ac729050 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -15,6 +15,7 @@ */ #include "src/ops/activation.h" +#include namespace mindspore { namespace lite { @@ -25,6 +26,21 @@ float Activation::GetAlpha() const { return this->primitive_->value.AsActivation void Activation::SetType(int type) { this->primitive_->value.AsActivation()->type = (schema::ActivationType)type; } void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; } +int Activation::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + if (prim.name() == "ReLU") { + attr->type = schema::ActivationType_RELU; + } else if (prim.name() == "Sigmoid") { + attr->type = schema::ActivationType_SIGMOID; + } else if (prim.name() == "ReLU6") { + attr->type = schema::ActivationType_RELU6; + } + this->primitive_->value.type = schema::PrimitiveType_Activation; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} #else int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); } diff --git a/mindspore/lite/src/ops/activation.h b/mindspore/lite/src/ops/activation.h index a4ab25c1d03..fead25192e5 100644 --- a/mindspore/lite/src/ops/activation.h +++ b/mindspore/lite/src/ops/activation.h @@ -27,9 +27,12 @@ namespace lite { class Activation : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Activation() = default; explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif int GetType() const; float GetAlpha() const; void SetType(int type); diff --git a/mindspore/lite/src/ops/activation_grad.h b/mindspore/lite/src/ops/activation_grad.h index 64e01a90a98..c6199442f80 100644 --- a/mindspore/lite/src/ops/activation_grad.h +++ b/mindspore/lite/src/ops/activation_grad.h @@ -28,10 +28,11 @@ namespace lite { class ActivationGrad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ActivationGrad() = default; explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetType() const; void SetType(int type); }; diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index 9328cbd4e7f..7f934428703 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -15,6 +15,7 @@ */ #include "src/ops/add.h" +#include namespace mindspore { namespace lite { @@ -25,6 +26,29 @@ void Add::SetActivationType(int activation_type) { this->primitive_->value.AsAdd()->activationType = (schema::ActivationType)activation_type; } +int Add::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Add; + } + if (this->primitive_->value.type != schema::PrimitiveType_Add) { + MS_LOG(ERROR) << "Primitive type should be add"; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::AddT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + #else int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } diff --git a/mindspore/lite/src/ops/add.h b/mindspore/lite/src/ops/add.h index 93c79cccd39..83f58a8ec87 100644 --- a/mindspore/lite/src/ops/add.h +++ b/mindspore/lite/src/ops/add.h @@ -33,10 +33,12 @@ namespace lite { class Add : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Add() = default; explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {} - +#endif int GetActivationType() const; void SetActivationType(int activation_type); }; diff --git a/mindspore/lite/src/ops/addn.h b/mindspore/lite/src/ops/addn.h index bd50b7d8502..bf09104b175 100644 --- a/mindspore/lite/src/ops/addn.h +++ b/mindspore/lite/src/ops/addn.h @@ -28,9 +28,11 @@ namespace lite { class AddN : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + AddN() = default; explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetN() const; void SetN(int n); diff --git a/mindspore/lite/src/ops/argmax.h b/mindspore/lite/src/ops/argmax.h index ed89d1a3a6e..dabca0b3332 100644 --- a/mindspore/lite/src/ops/argmax.h +++ b/mindspore/lite/src/ops/argmax.h @@ -28,10 +28,11 @@ namespace lite { class ArgMax : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ArgMax() = default; explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; bool GetOutMaxValue() const; diff --git a/mindspore/lite/src/ops/argmin.h b/mindspore/lite/src/ops/argmin.h index 35b6a3b696f..4d4ae653ef9 100644 --- a/mindspore/lite/src/ops/argmin.h +++ b/mindspore/lite/src/ops/argmin.h @@ -28,10 +28,11 @@ namespace lite { class ArgMin : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ArgMin() = default; explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; bool GetOutMaxValue() const; diff --git a/mindspore/lite/src/ops/arithmetic.h b/mindspore/lite/src/ops/arithmetic.h index eee942eb541..fcc0cda3ad6 100644 --- a/mindspore/lite/src/ops/arithmetic.h +++ b/mindspore/lite/src/ops/arithmetic.h @@ -28,10 +28,11 @@ namespace lite { class Arithmetic : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Arithmetic() = default; explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool Broadcasting() { return this->broadcasting_; } int NDims() { return this->ndim_; } diff --git a/mindspore/lite/src/ops/arithmetic_self.h b/mindspore/lite/src/ops/arithmetic_self.h index 8ecf5c8899e..d7df543a824 100644 --- a/mindspore/lite/src/ops/arithmetic_self.h +++ b/mindspore/lite/src/ops/arithmetic_self.h @@ -25,10 +25,11 @@ namespace lite { class ArithmeticSelf : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ArithmeticSelf() = default; explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/batch_norm.cc b/mindspore/lite/src/ops/batch_norm.cc index dc7e60015bd..3b68b19353b 100644 --- a/mindspore/lite/src/ops/batch_norm.cc +++ b/mindspore/lite/src/ops/batch_norm.cc @@ -15,7 +15,7 @@ */ #include "src/ops/batch_norm.h" - +#include namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -23,6 +23,15 @@ float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; } +int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + attr->epsilon = GetValue(prim.GetAttr("epsilon")); + this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm; + this->primitive_->value.value = attr.release(); + return RET_OK; +} + #else float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } diff --git a/mindspore/lite/src/ops/batch_norm.h b/mindspore/lite/src/ops/batch_norm.h index 4f3a10a9252..2567dddc910 100644 --- a/mindspore/lite/src/ops/batch_norm.h +++ b/mindspore/lite/src/ops/batch_norm.h @@ -28,10 +28,12 @@ namespace lite { class BatchNorm : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BatchNorm() = default; explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetEpsilon() const; void SetEpsilon(float epsilon); }; diff --git a/mindspore/lite/src/ops/batch_to_space.h b/mindspore/lite/src/ops/batch_to_space.h index e107803b033..f63205a3df4 100644 --- a/mindspore/lite/src/ops/batch_to_space.h +++ b/mindspore/lite/src/ops/batch_to_space.h @@ -28,10 +28,11 @@ namespace lite { class BatchToSpace : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BatchToSpace() = default; explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; std::vector GetCrops() const; diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 4cada2551cf..6966fd78c74 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -15,6 +15,7 @@ */ #include "src/ops/bias_add.h" +#include namespace mindspore { namespace lite { @@ -23,6 +24,16 @@ std::vector BiasAdd::GetAxis() const { return this->primitive_->value.AsBia void BiasAdd::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; } +int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + attr->axis = {0}; + this->primitive_->value.type = schema::PrimitiveType_BiasAdd; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} + #else std::vector BiasAdd::GetAxis() const { diff --git a/mindspore/lite/src/ops/bias_add.h b/mindspore/lite/src/ops/bias_add.h index 4526d0b3c74..19918fe19cc 100644 --- a/mindspore/lite/src/ops/bias_add.h +++ b/mindspore/lite/src/ops/bias_add.h @@ -28,10 +28,12 @@ namespace lite { class BiasAdd : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BiasAdd() = default; explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetAxis() const; void SetAxis(const std::vector &axis); }; diff --git a/mindspore/lite/src/ops/bias_grad.h b/mindspore/lite/src/ops/bias_grad.h index 9d6cfea64e8..d1525cea1fc 100644 --- a/mindspore/lite/src/ops/bias_grad.h +++ b/mindspore/lite/src/ops/bias_grad.h @@ -28,10 +28,11 @@ namespace lite { class BiasGrad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BiasGrad() = default; explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetAxis() const; void SetAxis(const std::vector &axis); }; diff --git a/mindspore/lite/src/ops/bn_grad_input.h b/mindspore/lite/src/ops/bn_grad_input.h index 5a138439300..3e6f0550f33 100644 --- a/mindspore/lite/src/ops/bn_grad_input.h +++ b/mindspore/lite/src/ops/bn_grad_input.h @@ -28,10 +28,11 @@ namespace lite { class BNGradInput : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BNGradInput() = default; explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetEps() const; int GetChannels() const; void SetEps(float eps); diff --git a/mindspore/lite/src/ops/broadcast_to.h b/mindspore/lite/src/ops/broadcast_to.h index 2afaf9e1e87..9b3cdaca10d 100644 --- a/mindspore/lite/src/ops/broadcast_to.h +++ b/mindspore/lite/src/ops/broadcast_to.h @@ -28,10 +28,11 @@ namespace lite { class BroadcastTo : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + BroadcastTo() = default; explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDstShape() const; void SetDstShape(const std::vector &dst_shape); diff --git a/mindspore/lite/src/ops/caffe_p_relu.h b/mindspore/lite/src/ops/caffe_p_relu.h index 92e7bf1feba..d78c456338d 100644 --- a/mindspore/lite/src/ops/caffe_p_relu.h +++ b/mindspore/lite/src/ops/caffe_p_relu.h @@ -28,10 +28,11 @@ namespace lite { class CaffePReLU : public Activation { public: #ifdef PRIMITIVE_WRITEABLE + CaffePReLU() = default; explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {} -#endif +#else explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {} - +#endif bool GetChannelShared() const; void SetChannelShared(bool channel_shared); }; diff --git a/mindspore/lite/src/ops/cast.h b/mindspore/lite/src/ops/cast.h index 45ceb800495..6a244611d07 100644 --- a/mindspore/lite/src/ops/cast.h +++ b/mindspore/lite/src/ops/cast.h @@ -28,10 +28,11 @@ namespace lite { class Cast : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Cast() = default; explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; int GetDstT() const; diff --git a/mindspore/lite/src/ops/ceil.h b/mindspore/lite/src/ops/ceil.h index 43dbc344ac5..5ce5276fb73 100644 --- a/mindspore/lite/src/ops/ceil.h +++ b/mindspore/lite/src/ops/ceil.h @@ -28,9 +28,11 @@ namespace lite { class Ceil : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Ceil() = default; explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/clip.h b/mindspore/lite/src/ops/clip.h index 2cda7fa5129..3f8289840e8 100644 --- a/mindspore/lite/src/ops/clip.h +++ b/mindspore/lite/src/ops/clip.h @@ -28,10 +28,11 @@ namespace lite { class Clip : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Clip() = default; explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetMax() const; float GetMin() const; void SetMax(float max); diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 5f087ea7a6b..7b53023410d 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -15,9 +15,11 @@ */ #include "src/ops/concat.h" +#include #include "include/errorcode.h" #include "utils/log_adapter.h" #include "src/ir/tensor.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -27,6 +29,16 @@ int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; } void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; } void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; } +int Concat::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + auto prim_axis = GetValue(prim.GetAttr("axis")); + attr->axis = prim_axis; + this->primitive_->value.type = schema::PrimitiveType_Concat; + this->primitive_->value.value = attr.release(); + return RET_OK; +} + #else int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } diff --git a/mindspore/lite/src/ops/concat.h b/mindspore/lite/src/ops/concat.h index 189b1387aac..5f2099e6532 100644 --- a/mindspore/lite/src/ops/concat.h +++ b/mindspore/lite/src/ops/concat.h @@ -28,10 +28,12 @@ namespace lite { class Concat : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Concat() = default; explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; int GetN() const; diff --git a/mindspore/lite/src/ops/constant_of_shape.h b/mindspore/lite/src/ops/constant_of_shape.h index 34e1791efbc..f9bb1d6581b 100644 --- a/mindspore/lite/src/ops/constant_of_shape.h +++ b/mindspore/lite/src/ops/constant_of_shape.h @@ -28,9 +28,11 @@ namespace lite { class ConstantOfShape : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ConstantOfShape() = default; explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetValue() const; void SetValue(float value); diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 0a0baf98d7b..268c15f9077 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -15,9 +15,14 @@ */ #include "src/ops/conv2d.h" +#include +#include #include "include/errorcode.h" #include "utils/log_adapter.h" #include "src/ir/tensor.h" +#ifdef PRIMITIVE_WRITEABLE +#include "tools/converter/quantizer/quantize_util.h" +#endif namespace mindspore { namespace lite { @@ -63,6 +68,265 @@ void Conv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2D()->has void Conv2D::SetActivationType(int activation_type) { this->primitive_->value.AsConv2D()->activationType = (schema::ActivationType)activation_type; } +template +void ConvertConvWeight(const ParameterPtr ¶m_node) { + MS_ASSERT(param_node != nullptr); + auto param = param_node->default_param(); + auto weight = std::dynamic_pointer_cast(param); + MS_ASSERT(weight != nullptr); + + std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); + if (buf == nullptr) { + MS_LOG(ERROR) << "new buf failed"; + return; + } + + size_t filter_k = weight->tensor_shape()[0]; + size_t filter_c = weight->tensor_shape()[1]; + size_t filter_h = weight->tensor_shape()[2]; + size_t filter_w = weight->tensor_shape()[3]; + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (size_t k = 0; k < filter_k; ++k) { + for (size_t c = 0; c < filter_c; ++c) { + for (size_t h = 0; h < filter_h; ++h) { + for (size_t w = 0; w < filter_w; ++w) { + p1Buff = reinterpret_cast(weight->tensor_addr()) + + ((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w)); + p2Buff = + buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w)); + *p2Buff = *p1Buff; + } + } + } + } + + auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(), + weight->tensor_shape_size() * sizeof(T)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return; + } + + auto abstract_base = param_node->abstract(); + MS_ASSERT(abstract_base != nullptr); + if (utils::isa(abstract_base)) { + auto abstract_tensor = utils::cast(abstract_base); + utils::cast(abstract_tensor->BuildShape())->shape()[0] = filter_c; + utils::cast(abstract_tensor->BuildShape())->shape()[1] = filter_k; + utils::cast(abstract_tensor->BuildShape())->shape()[2] = filter_h; + utils::cast(abstract_tensor->BuildShape())->shape()[3] = filter_w; + } + return; +} +void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, + const std::vector &inputs) { + auto attr = std::make_unique(); + auto format = GetValue(prim.GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(prim.GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(prim.GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(prim.GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(prim.GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + int channel_mutiplier = 1; + if (prim.GetAttr("channel_mutiplier") != nullptr) { + channel_mutiplier = GetValue(prim.GetAttr("channel_multiplier")); + } + attr->channelMultiplier = channel_mutiplier; + + MS_ASSERT(inputs.size() == kAnfPopulaterTwo); + auto input_node = inputs[kAnfPopulaterOne]; + MS_ASSERT(input_node != nullptr); + if (input_node->isa()) { + auto param_node = input_node->cast(); + ConvertConvWeight(param_node); + } + + primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; + primitive->value.value = attr.release(); +} + +void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { + auto attr = std::make_unique(); + attr->group = group; + auto format = GetValue(prim.GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(prim.GetAttr("pad_list")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(prim.GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(prim.GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + attr->channelOut = GetValue(prim.GetAttr("out_channel")); + + auto pad_mode = GetValue(prim.GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + primitive->value.type = schema::PrimitiveType_Conv2D; + primitive->value.value = attr.release(); +} + +void Conv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void Conv2D::PopulaterQuantParam(const Primitive &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { + auto narrow_range = prim.GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim.GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim.GetAttr("mean"); + auto std_dev = prim.GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim.GetAttr("input_minq"); + auto inputMax = prim.GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + int biasQuantSize = 0; + auto filterMin = prim.GetAttr("filter_minq"); + auto filterMax = prim.GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + biasQuantSize = filterMinPtr->DataSize(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecInputQuantParam->emplace_back(quants); + } + + quants.clear(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + + quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; + quants.emplace_back(quantParam); + } + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + auto outputMin = prim.GetAttr("output_minq"); + auto outputMax = prim.GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecOutputQuantParam->emplace_back(quants); + } +} + +int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + + int group = GetValue(prim.GetAttr("group")); + if (group > 1) { + PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); + } else { + PopulaterConv2DSingleGroup(prim, this->primitive_, group); + } + + if (GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + SetInputQuantParam(vecInputQuantParam); + SetOutputQuantParam(vecOutputQuantParam); + } + return RET_OK; +} #else diff --git a/mindspore/lite/src/ops/conv2d.h b/mindspore/lite/src/ops/conv2d.h index 2a1b718f3e8..4c769a952b6 100644 --- a/mindspore/lite/src/ops/conv2d.h +++ b/mindspore/lite/src/ops/conv2d.h @@ -20,17 +20,35 @@ #include #include #include -#include "ir/dtype/type_id.h" +#include #include "src/ops/primitive_c.h" +#include "ir/dtype/type_id.h" namespace mindspore { namespace lite { class Conv2D : public PrimitiveC { - public: #ifdef PRIMITIVE_WRITEABLE + + public: + Conv2D() = default; explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + + int UnPackAttr(const Primitive &prim, const std::vector &inputs); + + private: + void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, + const std::vector &inputs); + void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); + void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); +#else + + public: explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif + + public: int InferShape(std::vector inputs_, std::vector outputs_) override; int PadUp() const; int PadDown() const; diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.h b/mindspore/lite/src/ops/conv2d_grad_filter.h index 7ed1d696611..5e342e6ee42 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.h +++ b/mindspore/lite/src/ops/conv2d_grad_filter.h @@ -28,10 +28,11 @@ namespace lite { class Conv2DGradFilter : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Conv2DGradFilter() = default; explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetFormat() const; int GetGroup() const; int GetChannelIn() const; diff --git a/mindspore/lite/src/ops/conv2d_grad_input.h b/mindspore/lite/src/ops/conv2d_grad_input.h index 7c71d6d2095..ce923323181 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.h +++ b/mindspore/lite/src/ops/conv2d_grad_input.h @@ -28,10 +28,11 @@ namespace lite { class Conv2DGradInput : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Conv2DGradInput() = default; explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetFormat() const; int GetGroup() const; int GetChannelIn() const; diff --git a/mindspore/lite/src/ops/cos.h b/mindspore/lite/src/ops/cos.h index 1cc39df284a..4675c92eb53 100644 --- a/mindspore/lite/src/ops/cos.h +++ b/mindspore/lite/src/ops/cos.h @@ -28,9 +28,11 @@ namespace lite { class Cos : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Cos() = default; explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/crop.h b/mindspore/lite/src/ops/crop.h index c47402c043b..88d6276fd7e 100644 --- a/mindspore/lite/src/ops/crop.h +++ b/mindspore/lite/src/ops/crop.h @@ -28,10 +28,11 @@ namespace lite { class Crop : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Crop() = default; explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; long GetAxis() const; std::vector GetOffsets() const; diff --git a/mindspore/lite/src/ops/deconv2d.h b/mindspore/lite/src/ops/deconv2d.h index 0fe42927f00..5fbf25ab436 100644 --- a/mindspore/lite/src/ops/deconv2d.h +++ b/mindspore/lite/src/ops/deconv2d.h @@ -28,10 +28,11 @@ namespace lite { class DeConv2D : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + DeConv2D() = default; explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetGroup() const; diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.h b/mindspore/lite/src/ops/dedepthwise_conv2d.h index 689d2033bf3..f285b769645 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.h +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.h @@ -28,10 +28,11 @@ namespace lite { class DeDepthwiseConv2D : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + DeDepthwiseConv2D() = default; explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetChannelIn() const; diff --git a/mindspore/lite/src/ops/depth_to_space.h b/mindspore/lite/src/ops/depth_to_space.h index 3ae52550eef..4b02e6e948c 100644 --- a/mindspore/lite/src/ops/depth_to_space.h +++ b/mindspore/lite/src/ops/depth_to_space.h @@ -28,10 +28,11 @@ namespace lite { class DepthToSpace : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + DepthToSpace() = default; explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; int GetFormat() const; diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index 3903ce83493..5c1e8d4f246 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -15,7 +15,11 @@ */ #include "src/ops/depthwise_conv2d.h" - +#include +#include +#ifdef PRIMITIVE_WRITEABLE +#include "tools/converter/quantizer/quantize_util.h" +#endif namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -65,6 +69,168 @@ void DepthwiseConv2D::SetActivationType(int activation_type) { this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type; } +void DepthwiseConv2D::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void DepthwiseConv2D::PopulaterQuantParam(const Primitive &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { + auto narrow_range = prim.GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim.GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim.GetAttr("mean"); + auto std_dev = prim.GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim.GetAttr("input_minq"); + auto inputMax = prim.GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + int biasQuantSize = 0; + auto filterMin = prim.GetAttr("filter_minq"); + auto filterMax = prim.GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + biasQuantSize = filterMinPtr->DataSize(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecInputQuantParam->emplace_back(quants); + } + + quants.clear(); + for (int i = 0; i < biasQuantSize; ++i) { + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + + quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; + quants.emplace_back(quantParam); + } + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + auto outputMin = prim.GetAttr("output_minq"); + auto outputMax = prim.GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecOutputQuantParam->emplace_back(quants); + } +} + +int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + + auto format = GetValue(prim.GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + auto pad_list = GetValue>(prim.GetAttr("pads")); + attr->padUp = pad_list[0]; + attr->padDown = pad_list[1]; + attr->padLeft = pad_list[2]; + attr->padRight = pad_list[3]; + + auto dilation = GetValue>(prim.GetAttr("dilation")); + attr->dilateH = dilation[0]; + attr->dilateW = dilation[1]; + + auto kernel_size = GetValue>(prim.GetAttr("kernel_size")); + attr->kernelH = kernel_size[0]; + attr->kernelW = kernel_size[1]; + + auto stride = GetValue>(prim.GetAttr("stride")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + auto pad_mode = GetValue(prim.GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + auto channel_multiplier = GetValue(prim.GetAttr("channel_multiplier")); + attr->channelMultiplier = channel_multiplier; + + MS_ASSERT(inputs.size() == kAnfPopulaterTwo); + auto inputNode = inputs[kAnfPopulaterOne]; + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto paramNode = inputNode->cast(); + auto abstractBase = paramNode->abstract(); + MS_ASSERT(abstractBase != nullptr); + if (utils::isa(abstractBase)) { + auto abstractTensor = utils::cast(abstractBase); + MS_ASSERT(abstractTensor != nullptr); + if (utils::isa(abstractTensor->BuildShape())) { + auto dims = utils::cast(abstractTensor->BuildShape())->shape(); + attr->channelIn = dims[kAnfPopulaterOne]; + } + } + } + + this->primitive_->value.type = schema::PrimitiveType_DepthwiseConv2D; + this->primitive_->value.value = attr.release(); + + if (GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + SetInputQuantParam(vecInputQuantParam); + SetOutputQuantParam(vecOutputQuantParam); + } + return RET_OK; +} + #else int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); } diff --git a/mindspore/lite/src/ops/depthwise_conv2d.h b/mindspore/lite/src/ops/depthwise_conv2d.h index eb60575341c..ea92d565af4 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.h +++ b/mindspore/lite/src/ops/depthwise_conv2d.h @@ -26,12 +26,25 @@ namespace mindspore { namespace lite { class DepthwiseConv2D : public PrimitiveC { - public: #ifdef PRIMITIVE_WRITEABLE - explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif - explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} + public: + DepthwiseConv2D() = default; + explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + + int UnPackAttr(const Primitive &prim, const std::vector &inputs); + + private: + void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); +#else + + public: + explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif + + public: int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetChannelIn() const; diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h b/mindspore/lite/src/ops/dequant.cc similarity index 53% rename from mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h rename to mindspore/lite/src/ops/dequant.cc index d8b283e41ec..eabb28124c9 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.h +++ b/mindspore/lite/src/ops/dequant.cc @@ -13,18 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_ANF_MAKE_TUPLE_PARSER_H -#define MINDSPORE_ANF_MAKE_TUPLE_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" +#include "src/ops/dequant.h" #include +#include -namespace mindspore::lite { -class AnfMakeTuplePopulater : public AnfNodePopulater { - public: - AnfMakeTuplePopulater() = default; - ~AnfMakeTuplePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_ANF_MAKE_TUPLE_PARSER_H +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int Dequant::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize; + this->primitive_->value.value = attr.release(); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/dequant.h b/mindspore/lite/src/ops/dequant.h new file mode 100644 index 00000000000..d9553177a2e --- /dev/null +++ b/mindspore/lite/src/ops/dequant.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_ + +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Dequant : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + Dequant() = default; + explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else + explicit Dequant(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_DEQUANT_H_ diff --git a/mindspore/lite/src/ops/detection_post_process.h b/mindspore/lite/src/ops/detection_post_process.h index ab8624f0bf5..4d8601c2c99 100644 --- a/mindspore/lite/src/ops/detection_post_process.h +++ b/mindspore/lite/src/ops/detection_post_process.h @@ -28,10 +28,11 @@ namespace lite { class DetectionPostProcess : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + DetectionPostProcess() = default; explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetFormat() const; int GetInputSize() const; float GetHScale() const; diff --git a/mindspore/lite/src/ops/div.h b/mindspore/lite/src/ops/div.h index 043275cfcde..5f10b2bc236 100644 --- a/mindspore/lite/src/ops/div.h +++ b/mindspore/lite/src/ops/div.h @@ -28,10 +28,11 @@ namespace lite { class Div : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Div() = default; explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {} - +#endif int GetActivationType() const; void SetActivationType(int activation_type); }; diff --git a/mindspore/lite/src/ops/dropout.h b/mindspore/lite/src/ops/dropout.h index 3804ad16757..4a99d07a81e 100644 --- a/mindspore/lite/src/ops/dropout.h +++ b/mindspore/lite/src/ops/dropout.h @@ -28,10 +28,11 @@ namespace lite { class Dropout : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Dropout() = default; explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetRatio() const; void SetRatio(float ratio); }; diff --git a/mindspore/lite/src/ops/eltwise.h b/mindspore/lite/src/ops/eltwise.h index 0d24bb550d8..ba28b1f5339 100644 --- a/mindspore/lite/src/ops/eltwise.h +++ b/mindspore/lite/src/ops/eltwise.h @@ -28,10 +28,11 @@ namespace lite { class Eltwise : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Eltwise() = default; explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetMode() const; void SetMode(int mode); }; diff --git a/mindspore/lite/src/ops/elu.h b/mindspore/lite/src/ops/elu.h index 814ff4b4e77..2c08ea817df 100644 --- a/mindspore/lite/src/ops/elu.h +++ b/mindspore/lite/src/ops/elu.h @@ -28,10 +28,11 @@ namespace lite { class Elu : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Elu() = default; explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetAlpha() const; void SetAlpha(float alpha); }; diff --git a/mindspore/lite/src/ops/embedding_lookup.h b/mindspore/lite/src/ops/embedding_lookup.h index a0f42137942..82aa70f12eb 100644 --- a/mindspore/lite/src/ops/embedding_lookup.h +++ b/mindspore/lite/src/ops/embedding_lookup.h @@ -28,10 +28,11 @@ namespace lite { class EmbeddingLookup : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + EmbeddingLookup() = default; explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetMaxNorm() const; void SetMaxNorm(float max_norm); diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.h b/mindspore/lite/src/ops/embedding_lookup_sparse.h index d35f05a805c..a07a01c9919 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.h +++ b/mindspore/lite/src/ops/embedding_lookup_sparse.h @@ -28,10 +28,11 @@ namespace lite { class EmbeddingLookupSparse : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + EmbeddingLookupSparse() = default; explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetSpIds() const; std::vector GetSpWeights() const; float GetMaxNortm() const; diff --git a/mindspore/lite/src/ops/equal.h b/mindspore/lite/src/ops/equal.h index 60fa27b9792..82b0cf362c2 100644 --- a/mindspore/lite/src/ops/equal.h +++ b/mindspore/lite/src/ops/equal.h @@ -28,9 +28,11 @@ namespace lite { class Equal : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Equal() = default; explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/exp.h b/mindspore/lite/src/ops/exp.h index fbe17b848e9..a1f7eaa23be 100644 --- a/mindspore/lite/src/ops/exp.h +++ b/mindspore/lite/src/ops/exp.h @@ -28,9 +28,11 @@ namespace lite { class Exp : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Exp() = default; explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/expand_dims.h b/mindspore/lite/src/ops/expand_dims.h index 45a040cc3af..36404f1b257 100644 --- a/mindspore/lite/src/ops/expand_dims.h +++ b/mindspore/lite/src/ops/expand_dims.h @@ -28,10 +28,11 @@ namespace lite { class ExpandDims : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ExpandDims() = default; explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDim() const; void SetDim(int dim); diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h index e1d3babb3b8..ecc89074db2 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h @@ -28,10 +28,11 @@ namespace lite { class FakeQuantWithMinMaxVars : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + FakeQuantWithMinMaxVars() = default; explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif bool GetNarrowRange() const; int GetNumBits() const; void SetNarrowRange(bool narrow_range); diff --git a/mindspore/lite/src/ops/fill.h b/mindspore/lite/src/ops/fill.h index e7766722162..388e11bea7c 100644 --- a/mindspore/lite/src/ops/fill.h +++ b/mindspore/lite/src/ops/fill.h @@ -28,10 +28,11 @@ namespace lite { class Fill : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Fill() = default; explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDims() const; void SetDims(const std::vector &dims); diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index 8298824334c..333351922e9 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -15,6 +15,7 @@ */ #include "src/ops/flatten.h" +#include namespace mindspore { namespace lite { @@ -48,5 +49,14 @@ int Flatten::InferShape(std::vector inputs_, std::vectorset_shape(output_shape); return RET_OK; } +#ifdef PRIMITIVE_WRITEABLE +int Flatten::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_Flatten; + this->primitive_->value.value = attr.release(); + return RET_OK; +} +#endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/flatten.h b/mindspore/lite/src/ops/flatten.h index 1ab44c577f0..d7ab3baa75f 100644 --- a/mindspore/lite/src/ops/flatten.h +++ b/mindspore/lite/src/ops/flatten.h @@ -28,11 +28,14 @@ namespace lite { class Flatten : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Flatten() = default; explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; + + int UnPackAttr(const Primitive &prim, const std::vector &inputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor.h b/mindspore/lite/src/ops/floor.h index dc13935a980..5fc010249f2 100644 --- a/mindspore/lite/src/ops/floor.h +++ b/mindspore/lite/src/ops/floor.h @@ -28,9 +28,11 @@ namespace lite { class Floor : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Floor() = default; explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_div.h b/mindspore/lite/src/ops/floor_div.h index 2920e0f666c..a95c1f709b5 100644 --- a/mindspore/lite/src/ops/floor_div.h +++ b/mindspore/lite/src/ops/floor_div.h @@ -28,9 +28,11 @@ namespace lite { class FloorDiv : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + FloorDiv() = default; explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_mod.h b/mindspore/lite/src/ops/floor_mod.h index 294b1b34ffb..f20eb6cc49d 100644 --- a/mindspore/lite/src/ops/floor_mod.h +++ b/mindspore/lite/src/ops/floor_mod.h @@ -28,9 +28,11 @@ namespace lite { class FloorMod : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + FloorMod() = default; explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/full_connection.h b/mindspore/lite/src/ops/full_connection.h index 4ed808bbaf9..7bcb9b11668 100644 --- a/mindspore/lite/src/ops/full_connection.h +++ b/mindspore/lite/src/ops/full_connection.h @@ -28,10 +28,11 @@ namespace lite { class FullConnection : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + FullConnection() = default; explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetHasBias() const; int GetAxis() const; diff --git a/mindspore/lite/src/ops/fused_batchnorm.h b/mindspore/lite/src/ops/fused_batchnorm.h index dd96a7e533d..729cc934c66 100644 --- a/mindspore/lite/src/ops/fused_batchnorm.h +++ b/mindspore/lite/src/ops/fused_batchnorm.h @@ -28,10 +28,11 @@ namespace lite { class FusedBatchNorm : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + FusedBatchNorm() = default; explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetEpsilon() const; float GetMomentum() const; int GetSpatial() const; diff --git a/mindspore/lite/src/ops/gather.h b/mindspore/lite/src/ops/gather.h index dac98f8ca06..0006b190eb6 100644 --- a/mindspore/lite/src/ops/gather.h +++ b/mindspore/lite/src/ops/gather.h @@ -28,10 +28,11 @@ namespace lite { class Gather : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Gather() = default; explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; int GetBatchDims() const; diff --git a/mindspore/lite/src/ops/gather_nd.h b/mindspore/lite/src/ops/gather_nd.h index 43d8f963db6..7f0b1a79376 100644 --- a/mindspore/lite/src/ops/gather_nd.h +++ b/mindspore/lite/src/ops/gather_nd.h @@ -28,10 +28,11 @@ namespace lite { class GatherNd : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + GatherNd() = default; explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBatchDims() const; void SetBatchDims(int batch_dims); diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index f959deb966c..3547efbe74e 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -27,9 +27,11 @@ namespace lite { class Greater : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Greater() = default; explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.h b/mindspore/lite/src/ops/greater_equal.h index 72d74689f1f..5b97a43c9b7 100644 --- a/mindspore/lite/src/ops/greater_equal.h +++ b/mindspore/lite/src/ops/greater_equal.h @@ -28,9 +28,11 @@ namespace lite { class GreaterEqual : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + GreaterEqual() = default; explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/l2_norm.h b/mindspore/lite/src/ops/l2_norm.h index fe4a242fbe3..a54b0e91b3d 100644 --- a/mindspore/lite/src/ops/l2_norm.h +++ b/mindspore/lite/src/ops/l2_norm.h @@ -28,10 +28,11 @@ namespace lite { class L2Norm : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + L2Norm() = default; explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetAxis() const; float GetEpsilon() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/leaky_relu.h b/mindspore/lite/src/ops/leaky_relu.h index 02469ec5b89..1e021f35ba1 100644 --- a/mindspore/lite/src/ops/leaky_relu.h +++ b/mindspore/lite/src/ops/leaky_relu.h @@ -28,10 +28,11 @@ namespace lite { class LeakyReLU : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + LeakyReLU() = default; explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetNegativeSlope() const; void SetNegativeSlope(float negative_slope); }; diff --git a/mindspore/lite/src/ops/less.h b/mindspore/lite/src/ops/less.h index a5ccda4e071..d0205905e41 100644 --- a/mindspore/lite/src/ops/less.h +++ b/mindspore/lite/src/ops/less.h @@ -28,9 +28,11 @@ namespace lite { class Less : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Less() = default; explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.h b/mindspore/lite/src/ops/less_equal.h index 52cdf561689..c1e96ecb997 100644 --- a/mindspore/lite/src/ops/less_equal.h +++ b/mindspore/lite/src/ops/less_equal.h @@ -28,9 +28,11 @@ namespace lite { class LessEqual : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + LessEqual() = default; explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/local_response_normalization.h b/mindspore/lite/src/ops/local_response_normalization.h index 7b19e08961b..67557c5147c 100644 --- a/mindspore/lite/src/ops/local_response_normalization.h +++ b/mindspore/lite/src/ops/local_response_normalization.h @@ -28,10 +28,11 @@ namespace lite { class LocalResponseNormalization : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + LocalResponseNormalization() = default; explicit LocalResponseNormalization(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit LocalResponseNormalization(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetDepthRadius() const; float GetBias() const; float GetAlpha() const; diff --git a/mindspore/lite/src/ops/log.h b/mindspore/lite/src/ops/log.h index 88af6a906a7..0243af016a4 100644 --- a/mindspore/lite/src/ops/log.h +++ b/mindspore/lite/src/ops/log.h @@ -28,9 +28,11 @@ namespace lite { class Log : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Log() = default; explicit Log(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_and.h b/mindspore/lite/src/ops/logical_and.h index 5cbacaf63b3..a9a6bda890c 100644 --- a/mindspore/lite/src/ops/logical_and.h +++ b/mindspore/lite/src/ops/logical_and.h @@ -28,9 +28,11 @@ namespace lite { class LogicalAnd : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + LogicalAnd() = default; explicit LogicalAnd(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_not.h b/mindspore/lite/src/ops/logical_not.h index c1cf519a977..b6fc369f585 100644 --- a/mindspore/lite/src/ops/logical_not.h +++ b/mindspore/lite/src/ops/logical_not.h @@ -28,9 +28,11 @@ namespace lite { class LogicalNot : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + LogicalNot() = default; explicit LogicalNot(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_or.h b/mindspore/lite/src/ops/logical_or.h index 327a164d549..3571dd7086e 100644 --- a/mindspore/lite/src/ops/logical_or.h +++ b/mindspore/lite/src/ops/logical_or.h @@ -28,9 +28,11 @@ namespace lite { class LogicalOr : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + LogicalOr() = default; explicit LogicalOr(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/lrn.h b/mindspore/lite/src/ops/lrn.h index 8f718114285..0dd7b173476 100644 --- a/mindspore/lite/src/ops/lrn.h +++ b/mindspore/lite/src/ops/lrn.h @@ -28,10 +28,11 @@ namespace lite { class Lrn : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Lrn() = default; explicit Lrn(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Lrn(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetAlpha() const; float GetBeta() const; float GetBias() const; diff --git a/mindspore/lite/src/ops/lstm.h b/mindspore/lite/src/ops/lstm.h index 47adedff167..5260bed3f80 100644 --- a/mindspore/lite/src/ops/lstm.h +++ b/mindspore/lite/src/ops/lstm.h @@ -28,10 +28,11 @@ namespace lite { class Lstm : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Lstm() = default; explicit Lstm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Lstm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetBidirection() const; void SetBidirection(bool bidirection); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h b/mindspore/lite/src/ops/make_tuple.cc similarity index 55% rename from mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h rename to mindspore/lite/src/ops/make_tuple.cc index 28d3c96229a..ff102f13b35 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h +++ b/mindspore/lite/src/ops/make_tuple.cc @@ -13,18 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_ANF_QUANT_PARSER_H -#define MINDSPORE_ANF_QUANT_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfQuantPopulater : public AnfNodePopulater { - public: - AnfQuantPopulater() = default; - ~AnfQuantPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_ANF_QUANT_PARSER_H +#include "src/ops/make_tuple.h" +#include +#include + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int MakeTuple::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_MakeTuple; + this->primitive_->value.value = attr.release(); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/make_tuple.h b/mindspore/lite/src/ops/make_tuple.h new file mode 100644 index 00000000000..2559644997e --- /dev/null +++ b/mindspore/lite/src/ops/make_tuple.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class MakeTuple : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MakeTuple() = default; + explicit MakeTuple(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else + explicit MakeTuple(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_MAKE_TUPLE_H_ diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index a68de23fb71..3eba5716cfa 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -15,7 +15,11 @@ */ #include "src/ops/matmul.h" +#include #include +#ifdef PRIMITIVE_WRITEABLE +#include "tools/converter/quantizer/quantize_util.h" +#endif namespace mindspore { namespace lite { @@ -26,6 +30,102 @@ bool MatMul::GetTransposeB() const { return this->primitive_->value.AsMatMul()-> void MatMul::SetTransposeA(bool transpose_a) { this->primitive_->value.AsMatMul()->transposeA = transpose_a; } void MatMul::SetTransposeB(bool transpose_b) { this->primitive_->value.AsMatMul()->transposeB = transpose_b; } +void MatMul::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { + constexpr float qmin = 0; + constexpr float qmax = 255; + *mMin = static_cast((qmin - mean) / stdDev); + *mMax = static_cast((qmax - mean) / stdDev); +} + +void MatMul::PopulaterQuantParam(const Primitive &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { + auto narrow_range = prim.GetAttr("narrow_range"); + bool narrowRangeQuantParam = GetValue(narrow_range); + auto num_bits = prim.GetAttr("num_bits"); + int32_t numbitsRangeQuantParam = GetValue(num_bits); + + std::vector quants; + schema::QuantParamT quantParam; + auto mean = prim.GetAttr("mean"); + auto std_dev = prim.GetAttr("std_dev"); + if (mean != nullptr && std_dev != nullptr) { + auto meanQuantOaram = GetValue(mean); + double stddevQuantOaram = GetValue(std_dev); + float mMin = 0.0; + float mMax = 0.0; + CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); + quantParam.min = mMin; + quantParam.max = mMax; + } else { + auto inputMin = prim.GetAttr("input_minq"); + auto inputMax = prim.GetAttr("input_maxq"); + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->Data()); + float *maxBuf = static_cast(inputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + + quants.clear(); + auto filterMin = prim.GetAttr("filter_minq"); + auto filterMax = prim.GetAttr("filter_maxq"); + if (filterMin != nullptr && filterMax != nullptr) { + auto filterMinPtr = filterMin->cast(); + auto filterMaxPtr = filterMax->cast(); + float *minBuf = static_cast(filterMinPtr->Data()); + float *maxBuf = static_cast(filterMaxPtr->Data()); + for (int i = 0; i < filterMinPtr->DataSize(); ++i) { + quantParam.min = *(minBuf++); + quantParam.max = *(maxBuf++); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + } + vecInputQuantParam->emplace_back(quants); + } + + quants.clear(); + auto outputMin = prim.GetAttr("output_minq"); + auto outputMax = prim.GetAttr("output_maxq"); + if (outputMin != nullptr && outputMax != nullptr) { + auto outputMinPtr = outputMin->cast(); + auto outputMaxPtr = outputMax->cast(); + float *minBuf = static_cast(outputMinPtr->Data()); + float *maxBuf = static_cast(outputMaxPtr->Data()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); + quants.emplace_back(quantParam); + vecOutputQuantParam->emplace_back(quants); + } +} + +int MatMul::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + + auto attr = std::make_unique(); + attr->transposeA = GetValue(prim.GetAttr("transpose_a")); + attr->transposeB = GetValue(prim.GetAttr("transpose_b")); + + this->primitive_->value.type = schema::PrimitiveType_MatMul; + this->primitive_->value.value = attr.release(); + if (GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + SetInputQuantParam(vecInputQuantParam); + SetOutputQuantParam(vecOutputQuantParam); + } + return RET_OK; +} + #else bool MatMul::GetTransposeA() const { return this->primitive_->value_as_MatMul()->transposeA(); } diff --git a/mindspore/lite/src/ops/matmul.h b/mindspore/lite/src/ops/matmul.h index c25f4d9b938..2295aa1d399 100644 --- a/mindspore/lite/src/ops/matmul.h +++ b/mindspore/lite/src/ops/matmul.h @@ -20,18 +20,29 @@ #include #include #include -#include "ir/dtype/type_id.h" #include "src/ops/primitive_c.h" +#include "ir/dtype/type_id.h" namespace mindspore { namespace lite { class MatMul : public PrimitiveC { - public: #ifdef PRIMITIVE_WRITEABLE + public: + MatMul() = default; explicit MatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif - explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs); + private: + void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); +#else + + public: + explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif + + public: int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetTransposeA() const; bool GetTransposeB() const; diff --git a/mindspore/lite/src/ops/matrix_diag.h b/mindspore/lite/src/ops/matrix_diag.h index 14c41d60a8e..3b54632543f 100644 --- a/mindspore/lite/src/ops/matrix_diag.h +++ b/mindspore/lite/src/ops/matrix_diag.h @@ -28,10 +28,11 @@ namespace lite { class MatrixDiag : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + MatrixDiag() = default; explicit MatrixDiag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit MatrixDiag(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetK() const; int GetNumRows() const; int GetNumCols() const; diff --git a/mindspore/lite/src/ops/maximum.h b/mindspore/lite/src/ops/maximum.h index ba391b71acf..d123c559672 100644 --- a/mindspore/lite/src/ops/maximum.h +++ b/mindspore/lite/src/ops/maximum.h @@ -28,9 +28,11 @@ namespace lite { class Maximum : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Maximum() = default; explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mean.h b/mindspore/lite/src/ops/mean.h index 873cc75055b..0c87275ec17 100644 --- a/mindspore/lite/src/ops/mean.h +++ b/mindspore/lite/src/ops/mean.h @@ -28,10 +28,11 @@ namespace lite { class Mean : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Mean() = default; explicit Mean(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Mean(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; bool GetKeepDims() const; diff --git a/mindspore/lite/src/ops/minimum.h b/mindspore/lite/src/ops/minimum.h index 1b11fdaca35..9606ab5c216 100644 --- a/mindspore/lite/src/ops/minimum.h +++ b/mindspore/lite/src/ops/minimum.h @@ -28,9 +28,11 @@ namespace lite { class Minimum : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Minimum() = default; explicit Minimum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mul.cc b/mindspore/lite/src/ops/mul.cc index fea06d35388..ffe61d852fe 100644 --- a/mindspore/lite/src/ops/mul.cc +++ b/mindspore/lite/src/ops/mul.cc @@ -15,6 +15,7 @@ */ #include "src/ops/mul.h" +#include namespace mindspore { namespace lite { @@ -24,6 +25,14 @@ int Mul::GetActivationType() const { return this->primitive_->value.AsMul()->act void Mul::SetActivationType(int activation_type) { this->primitive_->value.AsMul()->activationType = (schema::ActivationType)activation_type; } +int Mul::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_Mul; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} #else diff --git a/mindspore/lite/src/ops/mul.h b/mindspore/lite/src/ops/mul.h index 97ca5fccabc..1962ae600fa 100644 --- a/mindspore/lite/src/ops/mul.h +++ b/mindspore/lite/src/ops/mul.h @@ -28,12 +28,15 @@ namespace lite { class Mul : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Mul() = default; explicit Mul(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {} - +#endif int GetActivationType() const; void SetActivationType(int activation_type); + + int UnPackAttr(const Primitive &prim, const std::vector &inputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/nchw2nhwc.h b/mindspore/lite/src/ops/nchw2nhwc.h index 3ea40ecc67f..f47ff6afd4b 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.h +++ b/mindspore/lite/src/ops/nchw2nhwc.h @@ -28,10 +28,11 @@ namespace lite { class Nchw2Nhwc : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Nchw2Nhwc() = default; explicit Nchw2Nhwc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Nchw2Nhwc(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/nhwc2nchw.h b/mindspore/lite/src/ops/nhwc2nchw.h index c7f05f845b1..232d7ab3879 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.h +++ b/mindspore/lite/src/ops/nhwc2nchw.h @@ -28,10 +28,11 @@ namespace lite { class Nhwc2Nchw : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Nhwc2Nchw() = default; explicit Nhwc2Nchw(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Nhwc2Nchw(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/not_equal.h b/mindspore/lite/src/ops/not_equal.h index f7c9a77bacd..2ec16ada795 100644 --- a/mindspore/lite/src/ops/not_equal.h +++ b/mindspore/lite/src/ops/not_equal.h @@ -28,9 +28,11 @@ namespace lite { class NotEqual : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + NotEqual() = default; explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/one_hot.h b/mindspore/lite/src/ops/one_hot.h index e9703c0d606..49e494b3dd5 100644 --- a/mindspore/lite/src/ops/one_hot.h +++ b/mindspore/lite/src/ops/one_hot.h @@ -28,10 +28,11 @@ namespace lite { class OneHot : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + OneHot() = default; explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit OneHot(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; void SetAxis(int axis); diff --git a/mindspore/lite/src/ops/pad.h b/mindspore/lite/src/ops/pad.h index bedeabf2263..8770133662d 100644 --- a/mindspore/lite/src/ops/pad.h +++ b/mindspore/lite/src/ops/pad.h @@ -28,10 +28,11 @@ namespace lite { class Pad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Pad() = default; explicit Pad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Pad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPaddings() const; int GetPaddingMode() const; diff --git a/mindspore/lite/src/ops/permute.h b/mindspore/lite/src/ops/permute.h index 2b5ee61efd9..f2f0fdb03bb 100644 --- a/mindspore/lite/src/ops/permute.h +++ b/mindspore/lite/src/ops/permute.h @@ -28,10 +28,11 @@ namespace lite { class Permute : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Permute() = default; explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetOrder() const; void SetOrder(const std::vector &order); }; diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index f24ec9600fe..d71dc33053e 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -15,6 +15,9 @@ */ #include "src/ops/pooling.h" +#include +#include +#include namespace mindspore { namespace lite { @@ -52,6 +55,47 @@ void Pooling::SetRoundMode(int round_mode) { this->primitive_->value.AsPooling()->roundMode = (schema::RoundMode)round_mode; } +int Pooling::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + if (prim.instance_name() == "MaxPool") { + attr->poolingMode = schema::PoolMode_MAX_POOLING; + } else if (prim.instance_name() == "MeanPool") { + attr->poolingMode = schema::PoolMode_MEAN_POOLING; + } + + auto format = GetValue(prim.GetAttr("data_format")); + if (format == "NCHW") { + attr->format = schema::Format_NCHW; + } else if (format == "NHWC") { + attr->format = schema::Format_NHWC; + } else { + attr->format = schema::Format_NUM_OF_FORMAT; + } + + auto pad_mode = GetValue(prim.GetAttr("padding")); + if (pad_mode == "VALID") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "SAME") { + attr->padMode = schema::PadMode_SAME; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + auto kernel_size = GetValue>(prim.GetAttr("ksize")); + attr->windowH = kernel_size[2]; + attr->windowW = kernel_size[3]; + + auto stride = GetValue>(prim.GetAttr("strides")); + attr->strideH = stride[2]; + attr->strideW = stride[3]; + + this->primitive_->value.type = schema::PrimitiveType_Pooling; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} + #else int Pooling::GetFormat() const { return this->primitive_->value_as_Pooling()->format(); } diff --git a/mindspore/lite/src/ops/pooling.h b/mindspore/lite/src/ops/pooling.h index 77bf8f261ac..eeac0e3b968 100644 --- a/mindspore/lite/src/ops/pooling.h +++ b/mindspore/lite/src/ops/pooling.h @@ -28,10 +28,11 @@ namespace lite { class Pooling : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Pooling() = default; explicit Pooling(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Pooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetPoolingMode() const; @@ -65,6 +66,8 @@ class Pooling : public PrimitiveC { int PadLeft() const; int PadRight() const; + int UnPackAttr(const Primitive &prim, const std::vector &inputs); + protected: int pad_u_ = 0; int pad_d_ = 0; diff --git a/mindspore/lite/src/ops/pooling_grad.h b/mindspore/lite/src/ops/pooling_grad.h index 4314b22d3a5..490bd7ddcbb 100644 --- a/mindspore/lite/src/ops/pooling_grad.h +++ b/mindspore/lite/src/ops/pooling_grad.h @@ -28,10 +28,11 @@ namespace lite { class PoolingGrad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + PoolingGrad() = default; explicit PoolingGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit PoolingGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetFormat() const; int GetPoolingMode() const; bool GetGlobal() const; diff --git a/mindspore/lite/src/ops/power.h b/mindspore/lite/src/ops/power.h index 1f252cf2c93..764da4028c3 100644 --- a/mindspore/lite/src/ops/power.h +++ b/mindspore/lite/src/ops/power.h @@ -28,10 +28,11 @@ namespace lite { class Power : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Power() = default; explicit Power(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Power(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetPower() const; float GetScale() const; diff --git a/mindspore/lite/src/ops/power_grad.h b/mindspore/lite/src/ops/power_grad.h index 29bd56515d6..3969197bfba 100644 --- a/mindspore/lite/src/ops/power_grad.h +++ b/mindspore/lite/src/ops/power_grad.h @@ -28,10 +28,11 @@ namespace lite { class PowerGrad : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + PowerGrad() = default; explicit PowerGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit PowerGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif float GetPower() const; float GetScale() const; float GetShift() const; diff --git a/mindspore/lite/src/ops/prelu.h b/mindspore/lite/src/ops/prelu.h index f57fd9f5c91..7dd71e9dc3b 100644 --- a/mindspore/lite/src/ops/prelu.h +++ b/mindspore/lite/src/ops/prelu.h @@ -28,10 +28,11 @@ namespace lite { class Prelu : public Activation { public: #ifdef PRIMITIVE_WRITEABLE + Prelu() = default; explicit Prelu(schema::PrimitiveT *primitive) : Activation(primitive) {} -#endif +#else explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {} - +#endif std::vector GetSlope() const; void SetSlope(const std::vector &slope); }; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index b98c09e1bd1..e3fcd3039f7 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -112,6 +112,10 @@ #include "src/ops/where.h" #include "src/ops/scatter_nd.h" #include "src/ops/constant_of_shape.h" +#include "src/ops/dequant.h" +#include "src/ops/make_tuple.h" +#include "src/ops/quant.h" +#include "src/ops/tuple_get_item.h" namespace mindspore { namespace lite { @@ -168,8 +172,253 @@ std::shared_ptr GetTupleGetItemPrim() { return std::make_shared(tuple_get_item_primitiveT); } -PrimitiveC *PrimitiveC::CreatePrimitive(mindspore::schema::Primitive *primitive) { +template ::value>> +std::shared_ptr NewPrimitiveC(const Primitive &prim, const std::vector &inputs) { + auto primc = std::make_shared(); + if (primc == nullptr) { + MS_LOG(ERROR) << "make_shared PrimitiveC failed"; + return nullptr; + } + auto ret = primc->UnPackAttr(prim, inputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnPackAttr failed"; + return nullptr; + } + return primc; +} + +std::shared_ptr PrimitiveC::UnPackFromPrimitive(const Primitive &prim, + const std::vector &inputs) { + const auto &op_type = prim.name(); + if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "BatchNorm") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "BiasAdd") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Concat") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Conv2D") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Dequant") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Flatten") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "make_tuple") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "MatMul") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Mul") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "MaxPool") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Quant") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "ReduceMean") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Reshape") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "TensorAdd") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "Transpose") { + return NewPrimitiveC(prim, inputs); + } else if (op_type == "tuple_getitem") { + return NewPrimitiveC(prim, inputs); + } else { + MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromPrimitive : " << op_type; + return nullptr; + } +} + +PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT *primitive) { MS_ASSERT(primitive != nullptr); + auto op_type = primitive->value.type; + switch (op_type) { + case schema::PrimitiveType_SoftMax: + return new SoftMax(primitive); + case schema::PrimitiveType_Activation: + return new Activation(primitive); + case schema::PrimitiveType_Conv2D: + return new Conv2D(primitive); + case schema::PrimitiveType_DeConv2D: + return new DeConv2D(primitive); + case schema::PrimitiveType_Reduce: + return new Reduce(primitive); + case schema::PrimitiveType_Pooling: + return new Pooling(primitive); + case schema::PrimitiveType_ROIPooling: + return new ROIPooling(primitive); + case schema::PrimitiveType_DepthwiseConv2D: + return new DepthwiseConv2D(primitive); + case schema::PrimitiveType_FusedBatchNorm: + return new FusedBatchNorm(primitive); + case schema::PrimitiveType_BatchNorm: + return new BatchNorm(primitive); + case schema::PrimitiveType_FullConnection: + return new FullConnection(primitive); + case schema::PrimitiveType_Power: + return new Power(primitive); + case schema::PrimitiveType_Pad: + return new Pad(primitive); + case schema::PrimitiveType_Range: + return new Range(primitive); + case schema::PrimitiveType_Mul: + return new Mul(primitive); + case schema::PrimitiveType_Add: + return new Add(primitive); + case schema::PrimitiveType_Sub: + return new Sub(primitive); + case schema::PrimitiveType_Div: + return new Div(primitive); + case schema::PrimitiveType_BiasAdd: + return new BiasAdd(primitive); + case schema::PrimitiveType_ExpandDims: + return new ExpandDims(primitive); + case schema::PrimitiveType_ArgMax: + return new ArgMax(primitive); + case schema::PrimitiveType_ArgMin: + return new ArgMin(primitive); + case schema::PrimitiveType_Cast: + return new Cast(primitive); + case schema::PrimitiveType_Reshape: + return new Reshape(primitive); + case schema::PrimitiveType_Scale: + return new Scale(primitive); + case schema::PrimitiveType_Eltwise: + return new Eltwise(primitive); + case schema::PrimitiveType_Ceil: + return new Ceil(primitive); + case schema::PrimitiveType_Concat: + return new Concat(primitive); + case schema::PrimitiveType_Fill: + return new Fill(primitive); + case schema::PrimitiveType_Nhwc2Nchw: + return new Nhwc2Nchw(primitive); + case schema::PrimitiveType_Nchw2Nhwc: + return new Nchw2Nhwc(primitive); + case schema::PrimitiveType_Transpose: + return new Transpose(primitive); + case schema::PrimitiveType_Slice: + return new Slice(primitive); + case schema::PrimitiveType_Squeeze: + return new Squeeze(primitive); + case schema::PrimitiveType_Flatten: + return new Flatten(primitive); + case schema::PrimitiveType_Mean: + return new Mean(primitive); + case schema::PrimitiveType_Stack: + return new Stack(primitive); + case schema::PrimitiveType_Crop: + return new Crop(primitive); + case schema::PrimitiveType_SquaredDifference: + return new SquaredDifference(primitive); + case schema::PrimitiveType_AddN: + return new AddN(primitive); + case schema::PrimitiveType_Abs: + return new Abs(primitive); + case schema::PrimitiveType_Sin: + return new Sin(primitive); + case schema::PrimitiveType_Cos: + return new Cos(primitive); + case schema::PrimitiveType_Log: + return new Log(primitive); + case schema::PrimitiveType_Sqrt: + return new Sqrt(primitive); + case schema::PrimitiveType_Rsqrt: + return new Rsqrt(primitive); + case schema::PrimitiveType_Square: + return new Square(primitive); + case schema::PrimitiveType_Exp: + return new Exp(primitive); + case schema::PrimitiveType_Gather: + return new Gather(primitive); + case schema::PrimitiveType_GatherNd: + return new GatherNd(primitive); + case schema::PrimitiveType_LocalResponseNormalization: + return new LocalResponseNormalization(primitive); + case schema::PrimitiveType_Maximum: + return new Maximum(primitive); + case schema::PrimitiveType_Minimum: + return new Minimum(primitive); + case schema::PrimitiveType_StridedSlice: + return new StridedSlice(primitive); + case schema::PrimitiveType_Prelu: + return new Prelu(primitive); + case schema::PrimitiveType_CaffePReLU: + return new CaffePReLU(primitive); + case schema::PrimitiveType_Round: + return new Round(primitive); + case schema::PrimitiveType_Reverse: + return new Reverse(primitive); + case schema::PrimitiveType_ReverseSequence: + return new ReverseSequence(primitive); + case schema::PrimitiveType_LogicalAnd: + return new LogicalAnd(primitive); + case schema::PrimitiveType_LogicalOr: + return new LogicalOr(primitive); + case schema::PrimitiveType_LogicalNot: + return new LogicalNot(primitive); + case schema::PrimitiveType_FloorDiv: + return new FloorDiv(primitive); + case schema::PrimitiveType_FloorMod: + return new FloorMod(primitive); + case schema::PrimitiveType_Equal: + return new Equal(primitive); + case schema::PrimitiveType_NotEqual: + return new NotEqual(primitive); + case schema::PrimitiveType_Less: + return new Less(primitive); + case schema::PrimitiveType_LessEqual: + return new LessEqual(primitive); + case schema::PrimitiveType_Greater: + return new Greater(primitive); + case schema::PrimitiveType_GreaterEqual: + return new GreaterEqual(primitive); + case schema::PrimitiveType_Floor: + return new Floor(primitive); + case schema::PrimitiveType_Split: + return new Split(primitive); + case schema::PrimitiveType_OneHot: + return new OneHot(primitive); + case schema::PrimitiveType_PriorBox: + return new PriorBox(primitive); + case schema::PrimitiveType_SpaceToDepth: + return new SpaceToDepth(primitive); + case schema::PrimitiveType_Tile: + return new Tile(primitive); + case schema::PrimitiveType_Resize: + return new Resize(primitive); + case schema::PrimitiveType_Unstack: + return new Unstack(primitive); + case schema::PrimitiveType_Unique: + return new Unique(primitive); + case schema::PrimitiveType_TopK: + return new TopK(primitive); + case schema::PrimitiveType_MatMul: + return new MatMul(primitive); + case schema::PrimitiveType_QuantDTypeCast: + return new QuantDTypeCast(primitive); + case schema::PrimitiveType_EmbeddingLookup: + return new EmbeddingLookup(primitive); + case schema::PrimitiveType_Elu: + return new Elu(primitive); + case schema::PrimitiveType_DeDepthwiseConv2D: + return new DeDepthwiseConv2D(primitive); + case schema::PrimitiveType_Shape: + return new Shape(primitive); + case schema::PrimitiveType_Unsqueeze: + return new Unsqueeze(primitive); + default: + MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : " + << schema::EnumNamePrimitiveType(op_type); + return nullptr; + } +} +#else +PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) { + MS_EXCEPTION_IF_NULL(primitive); auto op_type = primitive->value_type(); switch (op_type) { case schema::PrimitiveType_SoftMax: @@ -178,12 +427,12 @@ PrimitiveC *PrimitiveC::CreatePrimitive(mindspore::schema::Primitive *primitive) return new Activation(const_cast(primitive)); case schema::PrimitiveType_Conv2D: return new Conv2D(const_cast(primitive)); + case schema::PrimitiveType_DeConv2D: + return new DeConv2D(const_cast(primitive)); case schema::PrimitiveType_Reduce: return new Reduce(const_cast(primitive)); case schema::PrimitiveType_Pooling: return new Pooling(const_cast(primitive)); - case schema::PrimitiveType_ROIPooling: - return new ROIPooling(const_cast(primitive)); case schema::PrimitiveType_DepthwiseConv2D: return new DepthwiseConv2D(const_cast(primitive)); case schema::PrimitiveType_FusedBatchNorm: @@ -194,8 +443,6 @@ PrimitiveC *PrimitiveC::CreatePrimitive(mindspore::schema::Primitive *primitive) return new FullConnection(const_cast(primitive)); case schema::PrimitiveType_Power: return new Power(const_cast(primitive)); - case schema::PrimitiveType_Pad: - return new Pad(const_cast(primitive)); case schema::PrimitiveType_Range: return new Range(const_cast(primitive)); case schema::PrimitiveType_Mul: @@ -218,88 +465,137 @@ PrimitiveC *PrimitiveC::CreatePrimitive(mindspore::schema::Primitive *primitive) return new Cast(const_cast(primitive)); case schema::PrimitiveType_Reshape: return new Reshape(const_cast(primitive)); + case schema::PrimitiveType_Scale: + return new Scale(const_cast(primitive)); case schema::PrimitiveType_Eltwise: return new Eltwise(const_cast(primitive)); - case schema::PrimitiveType_Ceil: - return new Ceil(const_cast(primitive)); case schema::PrimitiveType_Concat: return new Concat(const_cast(primitive)); case schema::PrimitiveType_Fill: return new Fill(const_cast(primitive)); - case schema::PrimitiveType_Nhwc2Nchw: - return new Nhwc2Nchw(const_cast(primitive)); - case schema::PrimitiveType_Nchw2Nhwc: - return new Nchw2Nhwc(const_cast(primitive)); case schema::PrimitiveType_Transpose: return new Transpose(const_cast(primitive)); + case schema::PrimitiveType_Slice: + return new Slice(const_cast(primitive)); case schema::PrimitiveType_Squeeze: return new Squeeze(const_cast(primitive)); + case schema::PrimitiveType_Nchw2Nhwc: + return new Nchw2Nhwc(const_cast(primitive)); + case schema::PrimitiveType_Nhwc2Nchw: + return new Nhwc2Nchw(const_cast(primitive)); + case schema::PrimitiveType_Flatten: + return new Flatten(const_cast(primitive)); + case schema::PrimitiveType_Mean: + return new Mean(const_cast(primitive)); + case schema::PrimitiveType_Stack: + return new Stack(const_cast(primitive)); + case schema::PrimitiveType_Crop: + return new Crop(const_cast(primitive)); case schema::PrimitiveType_SquaredDifference: return new SquaredDifference(const_cast(primitive)); - case schema::PrimitiveType_Split: - return new Split(const_cast(primitive)); + case schema::PrimitiveType_AddN: + return new AddN(const_cast(primitive)); + case schema::PrimitiveType_Abs: + return new Abs(const_cast(primitive)); + case schema::PrimitiveType_Sin: + return new Sin(const_cast(primitive)); + case schema::PrimitiveType_Cos: + return new Cos(const_cast(primitive)); + case schema::PrimitiveType_Log: + return new Log(const_cast(primitive)); + case schema::PrimitiveType_Sqrt: + return new Sqrt(const_cast(primitive)); + case schema::PrimitiveType_Rsqrt: + return new Rsqrt(const_cast(primitive)); + case schema::PrimitiveType_Square: + return new Square(const_cast(primitive)); + case schema::PrimitiveType_Exp: + return new Exp(const_cast(primitive)); + case schema::PrimitiveType_Gather: + return new Gather(const_cast(primitive)); + case schema::PrimitiveType_GatherNd: + return new GatherNd(const_cast(primitive)); + case schema::PrimitiveType_LocalResponseNormalization: + return new LocalResponseNormalization(const_cast(primitive)); + case schema::PrimitiveType_Maximum: + return new Maximum(const_cast(primitive)); + case schema::PrimitiveType_Minimum: + return new Minimum(const_cast(primitive)); + case schema::PrimitiveType_Pad: + return new Pad(const_cast(primitive)); + case schema::PrimitiveType_StridedSlice: + return new StridedSlice(const_cast(primitive)); + case schema::PrimitiveType_Prelu: + return new Prelu(const_cast(primitive)); + case schema::PrimitiveType_CaffePReLU: + return new CaffePReLU(const_cast(primitive)); + case schema::PrimitiveType_Round: + return new Round(const_cast(primitive)); + case schema::PrimitiveType_Reverse: + return new Reverse(const_cast(primitive)); + case schema::PrimitiveType_ReverseSequence: + return new ReverseSequence(const_cast(primitive)); + case schema::PrimitiveType_LogicalAnd: + return new LogicalAnd(const_cast(primitive)); + case schema::PrimitiveType_LogicalOr: + return new LogicalOr(const_cast(primitive)); + case schema::PrimitiveType_LogicalNot: + return new LogicalNot(const_cast(primitive)); case schema::PrimitiveType_FloorDiv: return new FloorDiv(const_cast(primitive)); case schema::PrimitiveType_FloorMod: return new FloorMod(const_cast(primitive)); - case schema::PrimitiveType_Reverse: - return new Reverse(const_cast(primitive)); - case schema::PrimitiveType_Scale: - return new Scale(const_cast(primitive)); - case schema::PrimitiveType_GatherNd: - return new GatherNd(const_cast(primitive)); - case schema::PrimitiveType_Tile: - return new Tile(const_cast(primitive)); - case schema::PrimitiveType_TopK: - return new TopK(const_cast(primitive)); - case schema::PrimitiveType_Unique: - return new Unique(const_cast(primitive)); - case schema::PrimitiveType_Unstack: - return new Unstack(const_cast(primitive)); - case schema::PrimitiveType_ReverseSequence: - return new ReverseSequence(const_cast(primitive)); - case schema::PrimitiveType_Round: - return new Round(const_cast(primitive)); - case schema::PrimitiveType_ZerosLike: - return new ZerosLike(const_cast(primitive)); - case schema::PrimitiveType_Where: - return new Where(const_cast(primitive)); + case schema::PrimitiveType_Equal: + return new Equal(const_cast(primitive)); + case schema::PrimitiveType_NotEqual: + return new NotEqual(const_cast(primitive)); + case schema::PrimitiveType_Less: + return new Less(const_cast(primitive)); + case schema::PrimitiveType_LessEqual: + return new LessEqual(const_cast(primitive)); + case schema::PrimitiveType_Greater: + return new Greater(const_cast(primitive)); + case schema::PrimitiveType_GreaterEqual: + return new GreaterEqual(const_cast(primitive)); case schema::PrimitiveType_Floor: return new Floor(const_cast(primitive)); - case schema::PrimitiveType_Shape: - return new Shape(const_cast(primitive)); - case schema::PrimitiveType_ScatterND: - return new ScatterND(const_cast(primitive)); - case schema::PrimitiveType_Unsqueeze: - return new Unsqueeze(const_cast(primitive)); - case schema::PrimitiveType_Flatten: - return new Flatten(const_cast(primitive)); - case schema::PrimitiveType_StridedSlice: - return new StridedSlice(const_cast(primitive)); - case schema::PrimitiveType_Resize: - return new Resize(const_cast(primitive)); + case schema::PrimitiveType_Ceil: + return new Ceil(const_cast(primitive)); + case schema::PrimitiveType_Split: + return new Split(const_cast(primitive)); case schema::PrimitiveType_OneHot: return new OneHot(const_cast(primitive)); - case schema::PrimitiveType_PriorBox: - return new PriorBox(const_cast(primitive)); case schema::PrimitiveType_SpaceToDepth: return new SpaceToDepth(const_cast(primitive)); - case schema::PrimitiveType_SpaceToBatch: - return new SpaceToBatch(const_cast(primitive)); - case schema::PrimitiveType_QuantDTypeCast: - return new QuantDTypeCast(const_cast(primitive)); + case schema::PrimitiveType_Tile: + return new Tile(const_cast(primitive)); + case schema::PrimitiveType_Resize: + return new Resize(const_cast(primitive)); + case schema::PrimitiveType_Unstack: + return new Unstack(const_cast(primitive)); + case schema::PrimitiveType_Unique: + return new Unique(const_cast(primitive)); + case schema::PrimitiveType_TopK: + return new TopK(const_cast(primitive)); case schema::PrimitiveType_MatMul: return new MatMul(const_cast(primitive)); + case schema::PrimitiveType_QuantDTypeCast: + return new QuantDTypeCast(const_cast(primitive)); case schema::PrimitiveType_EmbeddingLookup: return new EmbeddingLookup(const_cast(primitive)); - case schema::PrimitiveType_ConstantOfShape: - return new ConstantOfShape(const_cast(primitive)); + case schema::PrimitiveType_Elu: + return new Elu(const_cast(primitive)); + case schema::PrimitiveType_DeDepthwiseConv2D: + return new DeDepthwiseConv2D(const_cast(primitive)); + case schema::PrimitiveType_Shape: + return new Shape(const_cast(primitive)); + case schema::PrimitiveType_Unsqueeze: + return new Unsqueeze(const_cast(primitive)); default: break; } return nullptr; } - #endif int PrimitiveC::Type() const { diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 80474b5d57b..86f7c3302f6 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -41,25 +41,32 @@ constexpr uint32_t kDimension_4d = 4; const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; #ifdef PRIMITIVE_WRITEABLE +constexpr int kAnfPopulaterOne = 1; +constexpr int kAnfPopulaterTwo = 2; +constexpr int kAnfPopulaterThree = 3; class PrimitiveC : public mindspore::Primitive { public: - explicit PrimitiveC(schema::Primitive *primitive) : Primitive("") { this->primitive_ = primitive->UnPack(); } - + // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete + // primitive explicit PrimitiveC(schema::PrimitiveT *primitive) : Primitive(""), primitive_(primitive) {} explicit PrimitiveC(const Primitive &prim) : Primitive(prim) {} + // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete + // primitive explicit PrimitiveC(const std::string &name, schema::PrimitiveT *primitive) : Primitive(name), primitive_(primitive) {} + PrimitiveC() : Primitive(""), primitive_(nullptr) {} + MS_DECLARE_PARENT(PrimitiveC, Primitive); - ~PrimitiveC() override = default; + ~PrimitiveC() override { + // delete this->primitive_; + } int Type() const; - // static PrimitiveC *UnPackFromPrimitive(const Primitive &prim); - schema::PrimitiveT *GetPrimitiveT() const; void SetPrimitiveT(schema::PrimitiveT *prim); @@ -99,10 +106,16 @@ class PrimitiveC : public mindspore::Primitive { void SetInferFlag(bool flag); - static PrimitiveC *CreatePrimitive(mindspore::schema::Primitive *primitive); + static PrimitiveC *UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) { + return UnPackFromSchemaPrimitiveT(primitive->UnPack()); + } + + static PrimitiveC *UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT *primitive); + + static std::shared_ptr UnPackFromPrimitive(const Primitive &prim, const std::vector &inputs); protected: - // virutal PrimitiveC *UnPackAttr(const Primitive &prim) = 0; + virtual int UnPackAttr(const Primitive &prim) { return RET_ERROR; } protected: schema::PrimitiveT *primitive_ = nullptr; @@ -117,16 +130,20 @@ std::shared_ptr GetMakeTuplePrim(); std::shared_ptr GetTupleGetItemPrim(); - - #else class PrimitiveC { public: PrimitiveC() = default; + // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete + // primitive explicit PrimitiveC(schema::Primitive *primitive) : primitive_(primitive) {} - virtual ~PrimitiveC() = default; + virtual ~PrimitiveC() { + // delete this->primitive_; + } + + static PrimitiveC *UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive); bool GetInferFlag() const; diff --git a/mindspore/lite/src/ops/prior_box.h b/mindspore/lite/src/ops/prior_box.h index 508cd88659a..6802a74479d 100644 --- a/mindspore/lite/src/ops/prior_box.h +++ b/mindspore/lite/src/ops/prior_box.h @@ -28,10 +28,11 @@ namespace lite { class PriorBox : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + PriorBox() = default; explicit PriorBox(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit PriorBox(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMinSizes() const; std::vector GetMaxSizes() const; diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h b/mindspore/lite/src/ops/quant.cc similarity index 54% rename from mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h rename to mindspore/lite/src/ops/quant.cc index 6abd8e907ba..ff824ad605c 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h +++ b/mindspore/lite/src/ops/quant.cc @@ -14,19 +14,21 @@ * limitations under the License. */ -#ifndef MINDSPORE_ANF_CONCAT_PARSER_H -#define MINDSPORE_ANF_CONCAT_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" +#include "src/ops/quant.h" #include +#include -namespace mindspore::lite { -class AnfConcatPopulater : public AnfNodePopulater { - public: - AnfConcatPopulater() = default; - ~AnfConcatPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtrr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int Quant::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Quantize; + this->primitive_->value.value = attr.release(); -#endif // MINDSPORE_ANF_CONCAT_PARSER_H + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/quant.h b/mindspore/lite/src/ops/quant.h new file mode 100644 index 00000000000..6ba178e9fa0 --- /dev/null +++ b/mindspore/lite/src/ops/quant.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_QUANT_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_QUANT_H_ +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Quant : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + Quant() = default; + explicit Quant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else + explicit Quant(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_QUANT_H_ diff --git a/mindspore/lite/src/ops/quant_dtype_cast.h b/mindspore/lite/src/ops/quant_dtype_cast.h index dbbb689d2fb..718019d1e37 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.h +++ b/mindspore/lite/src/ops/quant_dtype_cast.h @@ -28,10 +28,12 @@ namespace lite { class QuantDTypeCast : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + QuantDTypeCast() = default; explicit QuantDTypeCast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + MS_DECLARE_PARENT(QuantDTypeCast, PrimitiveC); +#else explicit QuantDTypeCast(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; int GetDstT() const; diff --git a/mindspore/lite/src/ops/range.h b/mindspore/lite/src/ops/range.h index b7abbc102d0..d1e5a13c1eb 100644 --- a/mindspore/lite/src/ops/range.h +++ b/mindspore/lite/src/ops/range.h @@ -28,10 +28,11 @@ namespace lite { class Range : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Range() = default; explicit Range(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Range(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDType() const; int GetStart() const; diff --git a/mindspore/lite/src/ops/rank.h b/mindspore/lite/src/ops/rank.h index b56f4560327..f2f39c2598d 100644 --- a/mindspore/lite/src/ops/rank.h +++ b/mindspore/lite/src/ops/rank.h @@ -28,10 +28,11 @@ namespace lite { class Rank : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Rank() = default; explicit Rank(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Rank(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 2bd709a5007..a0def401391 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -15,6 +15,7 @@ */ #include "src/ops/reduce.h" +#include namespace mindspore { namespace lite { @@ -27,6 +28,38 @@ void Reduce::SetAxes(const std::vector &axes) { this->primitive_->value.AsR void Reduce::SetKeepDims(int keep_dims) { this->primitive_->value.AsReduce()->keepDims = keep_dims; } void Reduce::SetMode(int mode) { this->primitive_->value.AsReduce()->mode = (schema::ReduceMode)mode; } +int Reduce::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + attr->mode = schema::ReduceMode_ReduceMean; + + attr->keepDims = GetValue(prim.GetAttr("keep_dims")); + if (inputs.size() == kAnfPopulaterTwo) { + auto inputNode = inputs[kAnfPopulaterOne]; + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->axes.emplace_back(elem->value()); + } + } + } + } + + this->primitive_->value.type = schema::PrimitiveType_Reduce; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} + #else std::vector Reduce::GetAxes() const { diff --git a/mindspore/lite/src/ops/reduce.h b/mindspore/lite/src/ops/reduce.h index 079834742c1..afccbf9cd8b 100644 --- a/mindspore/lite/src/ops/reduce.h +++ b/mindspore/lite/src/ops/reduce.h @@ -28,10 +28,12 @@ namespace lite { class Reduce : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Reduce() = default; explicit Reduce(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Reduce(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxes() const; int GetKeepDims() const; diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 1694d3154da..9bc492f5f26 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -15,6 +15,7 @@ */ #include "src/ops/reshape.h" +#include #include #include "include/errorcode.h" #include "utils/log_adapter.h" @@ -28,6 +29,32 @@ std::vector Reshape::GetShape() const { return this->primitive_->value.AsR void Reshape::SetFormat(int format) { this->primitive_->value.AsReshape()->format = (schema::Format)format; } void Reshape::SetShape(const std::vector &shape) { this->primitive_->value.AsReshape()->shape = shape; } +int Reshape::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + MS_ASSERT(inputs.size() == kAnfPopulaterThree - 1); + auto inputNode = inputs[kAnfPopulaterTwo - 1]; + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto val = valueNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); ++i) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + attr->shape.emplace_back(static_cast(elem->value())); + } + } + } + + this->primitive_->value.type = schema::PrimitiveType_Reshape; + this->primitive_->value.value = attr.release(); + + return RET_OK; +} #else diff --git a/mindspore/lite/src/ops/reshape.h b/mindspore/lite/src/ops/reshape.h index 0343d20bd39..8d45de76bde 100644 --- a/mindspore/lite/src/ops/reshape.h +++ b/mindspore/lite/src/ops/reshape.h @@ -28,10 +28,12 @@ namespace lite { class Reshape : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Reshape() = default; explicit Reshape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Reshape(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; std::vector GetShape() const; diff --git a/mindspore/lite/src/ops/resize.h b/mindspore/lite/src/ops/resize.h index 697aa360081..e92e6a0f315 100644 --- a/mindspore/lite/src/ops/resize.h +++ b/mindspore/lite/src/ops/resize.h @@ -28,10 +28,11 @@ namespace lite { class Resize : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Resize() = default; explicit Resize(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Resize(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; int GetMethod() const; diff --git a/mindspore/lite/src/ops/reverse.h b/mindspore/lite/src/ops/reverse.h index 9ac0ec7ce40..2b202e112d4 100644 --- a/mindspore/lite/src/ops/reverse.h +++ b/mindspore/lite/src/ops/reverse.h @@ -28,10 +28,11 @@ namespace lite { class Reverse : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Reverse() = default; explicit Reverse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Reverse(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetAxis() const; void SetAxis(const std::vector &axis); }; diff --git a/mindspore/lite/src/ops/reverse_sequence.h b/mindspore/lite/src/ops/reverse_sequence.h index 67e543a9d0f..66624bf8a5c 100644 --- a/mindspore/lite/src/ops/reverse_sequence.h +++ b/mindspore/lite/src/ops/reverse_sequence.h @@ -28,10 +28,11 @@ namespace lite { class ReverseSequence : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ReverseSequence() = default; explicit ReverseSequence(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ReverseSequence(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSeqAxis() const; int GetBatchAxis() const; diff --git a/mindspore/lite/src/ops/roi_pooling.h b/mindspore/lite/src/ops/roi_pooling.h index f3085f43898..d02720394a9 100644 --- a/mindspore/lite/src/ops/roi_pooling.h +++ b/mindspore/lite/src/ops/roi_pooling.h @@ -28,10 +28,11 @@ namespace lite { class ROIPooling : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ROIPooling() = default; explicit ROIPooling(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ROIPooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetPooledH() const; int GetPooledW() const; diff --git a/mindspore/lite/src/ops/round.h b/mindspore/lite/src/ops/round.h index b9c1fef1f1a..3e6496555eb 100644 --- a/mindspore/lite/src/ops/round.h +++ b/mindspore/lite/src/ops/round.h @@ -28,9 +28,11 @@ namespace lite { class Round : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Round() = default; explicit Round(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Round(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/rsqrt.h b/mindspore/lite/src/ops/rsqrt.h index 396ca2e187e..58a32ffd9b2 100644 --- a/mindspore/lite/src/ops/rsqrt.h +++ b/mindspore/lite/src/ops/rsqrt.h @@ -28,9 +28,11 @@ namespace lite { class Rsqrt : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Rsqrt() = default; explicit Rsqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Rsqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scale.h b/mindspore/lite/src/ops/scale.h index 10e68c80c1a..df1ebc5c86c 100644 --- a/mindspore/lite/src/ops/scale.h +++ b/mindspore/lite/src/ops/scale.h @@ -28,10 +28,11 @@ namespace lite { class Scale : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Scale() = default; explicit Scale(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Scale(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int GetAxis() const; void SetAxis(int axis); }; diff --git a/mindspore/lite/src/ops/scatter_nd.h b/mindspore/lite/src/ops/scatter_nd.h index 83cdf9fd0a9..69ec5dc301d 100644 --- a/mindspore/lite/src/ops/scatter_nd.h +++ b/mindspore/lite/src/ops/scatter_nd.h @@ -28,10 +28,11 @@ namespace lite { class ScatterND : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ScatterND() = default; explicit ScatterND(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ScatterND(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/shape.h b/mindspore/lite/src/ops/shape.h index 1178f4d9cca..ae6e1ceec99 100644 --- a/mindspore/lite/src/ops/shape.h +++ b/mindspore/lite/src/ops/shape.h @@ -28,10 +28,11 @@ namespace lite { class Shape : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Shape() = default; explicit Shape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Shape(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/ops/sin.h b/mindspore/lite/src/ops/sin.h index 0e48f50178b..ae410da36ca 100644 --- a/mindspore/lite/src/ops/sin.h +++ b/mindspore/lite/src/ops/sin.h @@ -28,9 +28,11 @@ namespace lite { class Sin : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Sin() = default; explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Sin(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 13008d0c937..11b984fc947 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -26,32 +26,32 @@ constexpr int kSliceInputNum = 1; constexpr int kSliceOutputNum = 1; } // namespace #ifdef PRIMITIVE_WRITEABLE -int SliceOp::GetFormat() const { return this->primitive_->value.AsSlice()->format; } -std::vector SliceOp::GetBegin() const { return this->primitive_->value.AsSlice()->begin; } -std::vector SliceOp::GetSize() const { return this->primitive_->value.AsSlice()->size; } +int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; } +std::vector Slice::GetBegin() const { return this->primitive_->value.AsSlice()->begin; } +std::vector Slice::GetSize() const { return this->primitive_->value.AsSlice()->size; } -void SliceOp::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; } -void SliceOp::SetBegin(const std::vector &begin) { this->primitive_->value.AsSlice()->begin = begin; } -void SliceOp::SetSize(const std::vector &size) { this->primitive_->value.AsSlice()->size = size; } +void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; } +void Slice::SetBegin(const std::vector &begin) { this->primitive_->value.AsSlice()->begin = begin; } +void Slice::SetSize(const std::vector &size) { this->primitive_->value.AsSlice()->size = size; } #else -int SliceOp::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } -std::vector SliceOp::GetBegin() const { +int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } +std::vector Slice::GetBegin() const { auto fb_vector = this->primitive_->value_as_Slice()->begin(); return std::vector(fb_vector->begin(), fb_vector->end()); } -std::vector SliceOp::GetSize() const { +std::vector Slice::GetSize() const { auto fb_vector = this->primitive_->value_as_Slice()->size(); return std::vector(fb_vector->begin(), fb_vector->end()); } -void SliceOp::SetFormat(int format) {} -void SliceOp::SetBegin(const std::vector &begin) {} -void SliceOp::SetSize(const std::vector &size) {} +void Slice::SetFormat(int format) {} +void Slice::SetBegin(const std::vector &begin) {} +void Slice::SetSize(const std::vector &size) {} #endif -int SliceOp::InferShape(std::vector inputs, std::vector outputs) { +int Slice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); diff --git a/mindspore/lite/src/ops/slice.h b/mindspore/lite/src/ops/slice.h index 769de4dbd3a..71b4dc5f3e9 100644 --- a/mindspore/lite/src/ops/slice.h +++ b/mindspore/lite/src/ops/slice.h @@ -25,13 +25,14 @@ namespace mindspore { namespace lite { -class SliceOp : public PrimitiveC { +class Slice : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE - explicit SliceOp(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + Slice() = default; + explicit Slice(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + explicit Slice(schema::Primitive *primitive) : PrimitiveC(primitive) {} #endif - explicit SliceOp(schema::Primitive *primitive) : PrimitiveC(primitive) {} - int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; std::vector GetBegin() const; diff --git a/mindspore/lite/src/ops/softmax.h b/mindspore/lite/src/ops/softmax.h index a77e552eb00..41e11ae0445 100644 --- a/mindspore/lite/src/ops/softmax.h +++ b/mindspore/lite/src/ops/softmax.h @@ -28,10 +28,11 @@ namespace lite { class SoftMax : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SoftMax() = default; explicit SoftMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SoftMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; void SetAxis(int axis); diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.h b/mindspore/lite/src/ops/softmax_cross_entropy.h index 5bd160b8e17..169b966cd93 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.h +++ b/mindspore/lite/src/ops/softmax_cross_entropy.h @@ -28,10 +28,11 @@ namespace lite { class SoftmaxCrossEntropy : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SoftmaxCrossEntropy() = default; explicit SoftmaxCrossEntropy(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SoftmaxCrossEntropy(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetAxis() const; void SetAxis(const std::vector &axis); }; diff --git a/mindspore/lite/src/ops/space_to_batch.h b/mindspore/lite/src/ops/space_to_batch.h index e5afd6bb6c9..3fb81398afa 100644 --- a/mindspore/lite/src/ops/space_to_batch.h +++ b/mindspore/lite/src/ops/space_to_batch.h @@ -28,10 +28,11 @@ namespace lite { class SpaceToBatch : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SpaceToBatch() = default; explicit SpaceToBatch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SpaceToBatch(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; std::vector GetPaddings() const; diff --git a/mindspore/lite/src/ops/space_to_batch_nd.h b/mindspore/lite/src/ops/space_to_batch_nd.h index 821ea1c2ea7..4cccdaeb6ae 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.h +++ b/mindspore/lite/src/ops/space_to_batch_nd.h @@ -28,10 +28,11 @@ namespace lite { class SpaceToBatchND : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SpaceToBatchND() = default; explicit SpaceToBatchND(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SpaceToBatchND(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetBlockShape() const; std::vector GetPaddings() const; void SetBlockShape(const std::vector &block_shape); diff --git a/mindspore/lite/src/ops/space_to_depth.h b/mindspore/lite/src/ops/space_to_depth.h index 9f374f3b8de..cd888825d62 100644 --- a/mindspore/lite/src/ops/space_to_depth.h +++ b/mindspore/lite/src/ops/space_to_depth.h @@ -28,10 +28,11 @@ namespace lite { class SpaceToDepth : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SpaceToDepth() = default; explicit SpaceToDepth(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SpaceToDepth(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; int GetFormat() const; diff --git a/mindspore/lite/src/ops/sparse_to_dense.h b/mindspore/lite/src/ops/sparse_to_dense.h index c35663853de..40c8798bd79 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.h +++ b/mindspore/lite/src/ops/sparse_to_dense.h @@ -28,10 +28,11 @@ namespace lite { class SparseToDense : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + SparseToDense() = default; explicit SparseToDense(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit SparseToDense(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::vector GetOutputShape() const; std::vector GetSparseValue() const; std::vector GetDefaultValue() const; diff --git a/mindspore/lite/src/ops/split.h b/mindspore/lite/src/ops/split.h index 433d3259f79..d8521329c30 100644 --- a/mindspore/lite/src/ops/split.h +++ b/mindspore/lite/src/ops/split.h @@ -28,10 +28,11 @@ namespace lite { class Split : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Split() = default; explicit Split(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Split(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNumberSplit() const; std::vector GetSizeSplits() const; diff --git a/mindspore/lite/src/ops/sqrt.h b/mindspore/lite/src/ops/sqrt.h index 88ceafa07eb..75202fd2537 100644 --- a/mindspore/lite/src/ops/sqrt.h +++ b/mindspore/lite/src/ops/sqrt.h @@ -28,9 +28,11 @@ namespace lite { class Sqrt : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Sqrt() = default; explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Sqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/square.h b/mindspore/lite/src/ops/square.h index dca72cfa5ee..52a8fa00af3 100644 --- a/mindspore/lite/src/ops/square.h +++ b/mindspore/lite/src/ops/square.h @@ -27,9 +27,11 @@ namespace lite { class Square : public ArithmeticSelf { public: #ifdef PRIMITIVE_WRITEABLE + Square() = default; explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} -#endif +#else explicit Square(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squared_difference.h b/mindspore/lite/src/ops/squared_difference.h index e1d88db06cb..e625f8a2fb1 100644 --- a/mindspore/lite/src/ops/squared_difference.h +++ b/mindspore/lite/src/ops/squared_difference.h @@ -28,9 +28,11 @@ namespace lite { class SquaredDifference : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + SquaredDifference() = default; explicit SquaredDifference(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squeeze.h b/mindspore/lite/src/ops/squeeze.h index 7ca9adbcc54..30b8968b8f8 100644 --- a/mindspore/lite/src/ops/squeeze.h +++ b/mindspore/lite/src/ops/squeeze.h @@ -28,10 +28,11 @@ namespace lite { class Squeeze : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Squeeze() = default; explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Squeeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index 4a1d5f02736..c7e13175a20 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -84,8 +84,8 @@ int Stack::InferShape(std::vector inputs, std::vectordata_type() != input0_data_type) { - MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = " - << inputs[i]->data_type(); + MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i + << "] data type = " << inputs[i]->data_type(); return RET_PARAM_INVALID; } } diff --git a/mindspore/lite/src/ops/stack.h b/mindspore/lite/src/ops/stack.h index 7f3a2725f40..37930cb4341 100644 --- a/mindspore/lite/src/ops/stack.h +++ b/mindspore/lite/src/ops/stack.h @@ -28,10 +28,11 @@ namespace lite { class Stack : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Stack() = default; explicit Stack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Stack(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; int GetN() const; diff --git a/mindspore/lite/src/ops/strided_slice.h b/mindspore/lite/src/ops/strided_slice.h index 66df8b29e35..811d8bac155 100644 --- a/mindspore/lite/src/ops/strided_slice.h +++ b/mindspore/lite/src/ops/strided_slice.h @@ -28,10 +28,11 @@ namespace lite { class StridedSlice : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + StridedSlice() = default; explicit StridedSlice(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit StridedSlice(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBeginMask() const; int GetEndMask() const; diff --git a/mindspore/lite/src/ops/sub.h b/mindspore/lite/src/ops/sub.h index 2738d183978..1f6d90c9faa 100644 --- a/mindspore/lite/src/ops/sub.h +++ b/mindspore/lite/src/ops/sub.h @@ -28,10 +28,11 @@ namespace lite { class Sub : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE + Sub() = default; explicit Sub(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} -#endif +#else explicit Sub(schema::Primitive *primitive) : Arithmetic(primitive) {} - +#endif int GetActivationType() const; void SetActivationType(int activation_type); }; diff --git a/mindspore/lite/src/ops/tile.h b/mindspore/lite/src/ops/tile.h index dfc025e5af0..187129485e7 100644 --- a/mindspore/lite/src/ops/tile.h +++ b/mindspore/lite/src/ops/tile.h @@ -28,10 +28,11 @@ namespace lite { class Tile : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Tile() = default; explicit Tile(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Tile(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMultiples() const; void SetMultiples(const std::vector &multiples); diff --git a/mindspore/lite/src/ops/topk.h b/mindspore/lite/src/ops/topk.h index 7b58ecaab2f..1c23040537c 100644 --- a/mindspore/lite/src/ops/topk.h +++ b/mindspore/lite/src/ops/topk.h @@ -28,10 +28,11 @@ namespace lite { class TopK : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + TopK() = default; explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit TopK(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetK() const; bool GetSorted() const; diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index ef6750015f2..f69e5c37b70 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -15,6 +15,7 @@ */ #include "src/ops/transpose.h" +#include #include "include/errorcode.h" #include "utils/log_adapter.h" @@ -27,6 +28,32 @@ bool Transpose::GetConjugate() const { return this->primitive_->value.AsTranspos void Transpose::SetPerm(const std::vector &perm) { this->primitive_->value.AsTranspose()->perm = perm; } void Transpose::SetConjugate(bool conjugate) { this->primitive_->value.AsTranspose()->conjugate = conjugate; } +int Transpose::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + MS_ASSERT(inputs.size() == kAnfPopulaterTwo); + auto inputNode = inputs[kAnfPopulaterOne]; + if (inputNode->isa()) { + auto valNode = inputNode->cast(); + MS_ASSERT(valNode != nullptr); + auto val = valNode->value(); + MS_ASSERT(val != nullptr); + if (val->isa()) { + auto tuple = val->cast(); + MS_ASSERT(tuple != nullptr); + for (size_t i = 0; i < tuple->size(); i++) { + auto elem = tuple->value()[i]->cast(); + MS_ASSERT(elem != nullptr); + attr->perm.emplace_back(static_cast(elem->value())); + } + } + } + + this->primitive_->value.type = schema::PrimitiveType_Transpose; + this->primitive_->value.value = attr.release(); + return RET_OK; +} + #else std::vector Transpose::GetPerm() const { diff --git a/mindspore/lite/src/ops/transpose.h b/mindspore/lite/src/ops/transpose.h index be11939bec2..4cde724030e 100644 --- a/mindspore/lite/src/ops/transpose.h +++ b/mindspore/lite/src/ops/transpose.h @@ -28,10 +28,12 @@ namespace lite { class Transpose : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Transpose() = default; explicit Transpose(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else explicit Transpose(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPerm() const; bool GetConjugate() const; diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h b/mindspore/lite/src/ops/tuple_get_item.cc similarity index 54% rename from mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h rename to mindspore/lite/src/ops/tuple_get_item.cc index 20f55f8da60..1abf6b18a23 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h +++ b/mindspore/lite/src/ops/tuple_get_item.cc @@ -13,19 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H -#define MINDSPORE_ANF_DEQUANT_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" + +#include "src/ops/tuple_get_item.h" #include +#include -namespace mindspore::lite { -class AnfDequantPopulater : public AnfNodePopulater { - public: - AnfDequantPopulater() = default; - ~AnfDequantPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int TupleGetItem::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + this->primitive_ = new (schema::PrimitiveT); + auto attr = std::make_unique(); + this->primitive_->value.type = schema::PrimitiveType_TupleGetItem; + this->primitive_->value.value = attr.release(); -#endif // MINDSPORE_ANF_DEQUANT_PARSER_H + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/tuple_get_item.h b/mindspore/lite/src/ops/tuple_get_item.h new file mode 100644 index 00000000000..729a1cfc9b3 --- /dev/null +++ b/mindspore/lite/src/ops/tuple_get_item.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_SRC_OPS_TUPLE_GET_ITEM_H_ +#define LITE_MINDSPORE_LITE_SRC_OPS_TUPLE_GET_ITEM_H_ + +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class TupleGetItem : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + TupleGetItem() = default; + explicit TupleGetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs); +#else + explicit TupleGetItem(schema::Primitive *primitive) : PrimitiveC(primitive) {} +#endif +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_SRC_OPS_TUPLE_GET_ITEM_H_ diff --git a/mindspore/lite/src/ops/unique.h b/mindspore/lite/src/ops/unique.h index c623ab89e97..c8ca722abf6 100644 --- a/mindspore/lite/src/ops/unique.h +++ b/mindspore/lite/src/ops/unique.h @@ -28,10 +28,11 @@ namespace lite { class Unique : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Unique() = default; explicit Unique(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Unique(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetOutType() const; void SetOutType(int out_type); diff --git a/mindspore/lite/src/ops/unsqueeze.h b/mindspore/lite/src/ops/unsqueeze.h index 1873feaa678..bea5831fcd8 100644 --- a/mindspore/lite/src/ops/unsqueeze.h +++ b/mindspore/lite/src/ops/unsqueeze.h @@ -28,10 +28,11 @@ namespace lite { class Unsqueeze : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Unsqueeze() = default; explicit Unsqueeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Unsqueeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; void SetAxis(const std::vector &axis); diff --git a/mindspore/lite/src/ops/unstack.h b/mindspore/lite/src/ops/unstack.h index 8c7b357e5dd..337f74ab7f1 100644 --- a/mindspore/lite/src/ops/unstack.h +++ b/mindspore/lite/src/ops/unstack.h @@ -28,10 +28,11 @@ namespace lite { class Unstack : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Unstack() = default; explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Unstack(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNum() const; int GetAxis() const; diff --git a/mindspore/lite/src/ops/upsample.h b/mindspore/lite/src/ops/upsample.h index 61f0869ae74..26df8d7604b 100644 --- a/mindspore/lite/src/ops/upsample.h +++ b/mindspore/lite/src/ops/upsample.h @@ -29,10 +29,11 @@ namespace lite { class Upsample : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Upsample() = default; explicit Upsample(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Upsample(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif std::string GetMode() const; std::vector GetScales() const; void SetMode(std::string mode); diff --git a/mindspore/lite/src/ops/where.h b/mindspore/lite/src/ops/where.h index 9279a147b86..7db38f9a5d9 100644 --- a/mindspore/lite/src/ops/where.h +++ b/mindspore/lite/src/ops/where.h @@ -28,10 +28,11 @@ namespace lite { class Where : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + Where() = default; explicit Where(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit Where(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetCondition() const; void SetCondition(const std::vector &condition); diff --git a/mindspore/lite/src/ops/zeros_like.h b/mindspore/lite/src/ops/zeros_like.h index a509fa1b021..36524220df3 100644 --- a/mindspore/lite/src/ops/zeros_like.h +++ b/mindspore/lite/src/ops/zeros_like.h @@ -28,10 +28,11 @@ namespace lite { class ZerosLike : public PrimitiveC { public: #ifdef PRIMITIVE_WRITEABLE + ZerosLike() = default; explicit ZerosLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} -#endif +#else explicit ZerosLike(schema::Primitive *primitive) : PrimitiveC(primitive) {} - +#endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; } // namespace lite diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 33feaaba287..33621ea6151 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "src/populate_parameter.h" +#include "src/ops/primitive_c.h" +#include "utils/log_adapter.h" +#include "schema/ops_generated.h" #include "src/ops/constant_of_shape.h" #include "src/ops/space_to_batch.h" #include "src/ops/conv2d.h" @@ -106,18 +110,14 @@ #include "src/ops/squared_difference.h" #include "src/ops/ceil.h" #include "src/ops/round.h" -#include "src/ops/primitive_c.h" -#include "src/populate_parameter.h" -#include "utils/log_adapter.h" -#include "schema/ops_generated.h" #include "nnacl/op_base.h" #include "nnacl/fp32/arg_min_max.h" #include "nnacl/fp32/cast.h" #include "nnacl/concat_parameter.h" -#include "nnacl/prelu_parameter.h" #include "nnacl/fp32/slice.h" #include "nnacl/fp32/broadcast_to.h" #include "nnacl/reshape_parameter.h" +#include "nnacl/prelu_parameter.h" #include "nnacl/shape.h" #include "nnacl/fp32/constant_of_shape.h" #include "nnacl/fp32/stack.h" @@ -154,11 +154,11 @@ #include "nnacl/scatter_nd.h" #include "nnacl/batch_to_space.h" #include "nnacl/fp32/crop.h" -#include "src/runtime/kernel/arm/fp32/flatten.h" +#include "fp32/flatten.h" #include "nnacl/fp32/unsqueeze.h" #include "nnacl/fp32/one_hot.h" #include "nnacl/strided_slice.h" -#include "src/runtime/kernel/arm/base/prior_box.h" +#include "base/prior_box.h" #include "nnacl/fp32/space_to_depth.h" #include "nnacl/fp32/space_to_batch.h" #include "nnacl/int8/quant_dtype_cast.h" @@ -998,7 +998,7 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive MS_LOG(ERROR) << "new SliceParameter failed."; return nullptr; } - auto param = reinterpret_cast(const_cast(primitive)); + auto param = reinterpret_cast(const_cast(primitive)); slice_param->op_parameter_.type_ = primitive->Type(); auto param_begin = param->GetBegin(); auto param_size = param->GetSize(); diff --git a/mindspore/lite/src/populate_parameter.h b/mindspore/lite/src/populate_parameter.h index 1ba686543d9..5397b9bb375 100644 --- a/mindspore/lite/src/populate_parameter.h +++ b/mindspore/lite/src/populate_parameter.h @@ -17,10 +17,9 @@ #ifndef MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ #define MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ -#include "schema/model_generated.h" - -#include "nnacl/op_base.h" #include "src/ops/primitive_c.h" +#include "schema/model_generated.h" +#include "nnacl/op_base.h" namespace mindspore::kernel { typedef OpParameter *(*PopulateParameterFunc)(const mindspore::lite::PrimitiveC *); diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index d2bd52f6cca..6a7402688c3 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -21,6 +21,7 @@ #include #include +#include "src/ops/quant_dtype_cast.h" #include "abstract/abstract_value.h" #include "mindspore/core/ir/primitive.h" #include "src/ir/tensor.h" @@ -67,7 +68,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; // activation auto input_quant_params = primitive->GetInputQuantParams(); - auto node_type = (schema::PrimitiveType) primitive->Type(); + auto node_type = (schema::PrimitiveType)primitive->Type(); for (size_t i = 0; i < input_quant_params.size(); i++) { if (i >= dst_node->inputIndex.size()) { MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() @@ -104,10 +105,17 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me } } } - if (dst_node->quantType != schema::QuantType_AwareTraining && - !(node_type == schema::PrimitiveType_QuantDTypeCast && - primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { - tensor_output->dataType = kNumberTypeInt8; + if (dst_node->quantType == schema::QuantType_PostTraining) { + if (node_type != schema::PrimitiveType_QuantDTypeCast) { + tensor_output->dataType = kNumberTypeInt8; + } else { + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + if (primc->GetDstT() != kNumberTypeFloat32) { + tensor_output->dataType = kNumberTypeInt8; + } + } } } return RET_OK; @@ -130,24 +138,24 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr & void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT *return_node) { - MS_ASSERT(nullptr != meta_graph); - MS_ASSERT(nullptr != return_node); - for (size_t i = 1; i < cnode->inputs().size(); i++) { - auto input_node = cnode->input(i); - if (input_node->isa()) { - auto ret = ConvertInputCNode(input_node, return_node); - if (ret != RET_OK) { - MS_LOG(ERROR) << "obtain outputs failed"; - return; - } - } else { - MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; + MS_ASSERT(nullptr != meta_graph); + MS_ASSERT(nullptr != return_node); + for (size_t i = 1; i < cnode->inputs().size(); i++) { + auto input_node = cnode->input(i); + if (input_node->isa()) { + auto ret = ConvertInputCNode(input_node, return_node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "obtain outputs failed"; return; } + } else { + MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; + return; } - for (size_t i = 0; i < return_node->inputIndex.size(); ++i) { - meta_graphT->outputIndex.push_back(return_node->inputIndex[i]); - } + } + for (size_t i = 0; i < return_node->inputIndex.size(); ++i) { + meta_graphT->outputIndex.push_back(return_node->inputIndex[i]); + } } schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { @@ -160,12 +168,8 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { return nullptr; } auto primT = primitiveT_value->GetPrimitiveT(); - if (primT == nullptr) { - MS_LOG(ERROR) << "PrimitiveT is nullptr"; - return nullptr; - } - if (primT->value.type == schema::PrimitiveType_TupleGetItem || - primT->value.type == schema::PrimitiveType_MakeTuple) { + if (primitiveT_value->Type() == schema::PrimitiveType_TupleGetItem || + primitiveT_value->Type() == schema::PrimitiveType_MakeTuple) { continue; } RemoveIfMakeTuple(cnode); @@ -375,9 +379,9 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrallTensors.size(); } meta_graphT->allTensors.emplace_back(msTensor); - if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) - || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) - || IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) { + if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) { break; } } @@ -401,7 +405,7 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType if (prim == nullptr) { return false; } - return (schema::PrimitiveType) prim->Type() == type; + return (schema::PrimitiveType)prim->Type() == type; } schema::MetaGraphT *Export(const FuncGraphPtr &func_graph) { diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc deleted file mode 100644 index 8310366196a..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_activation_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - if (prim->name() == "ReLU") { - attr->type = schema::ActivationType_RELU; - } else if (prim->name() == "Sigmoid") { - attr->type = schema::ActivationType_SIGMOID; - } else if (prim->name() == "ReLU6") { - attr->type = schema::ActivationType_RELU6; - } - - primitive->value.type = schema::PrimitiveType_Activation; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater()); -AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater()); -AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h deleted file mode 100644 index f2f18f9c0d1..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H -#define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - - -namespace mindspore::lite { -class AnfActivationPopulater : public AnfNodePopulater { - public: - AnfActivationPopulater() = default; - ~AnfActivationPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc deleted file mode 100644 index 64deba38aca..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_batchnorm_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - attr->epsilon = GetValue(prim->GetAttr("epsilon")); - primitive->value.type = schema::PrimitiveType_FusedBatchNorm; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h deleted file mode 100644 index d4c39e0b710..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H -#define MINDSPORE_ANF_BATCHNORM_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfBatchnormPopulater : public AnfNodePopulater { - public: - AnfBatchnormPopulater() = default; - ~AnfBatchnormPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc deleted file mode 100644 index 856720578f1..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_biasadd_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - attr->axis = {0}; - primitive->value.type = schema::PrimitiveType_BiasAdd; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} - -AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h deleted file mode 100644 index bcb3db8cc7f..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_BIASADD_PARSER_H -#define MINDSPORE_ANF_BIASADD_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfBiasAddPopulater : public AnfNodePopulater { - public: - AnfBiasAddPopulater() = default; - ~AnfBiasAddPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_BIASADD_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc deleted file mode 100644 index 50bc16106e3..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_concat_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - auto prim_axis = GetValue(prim->GetAttr("axis")); - attr->axis = prim_axis; - primitive->value.type = schema::PrimitiveType_Concat; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} - -AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc deleted file mode 100644 index cbef4d5047f..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_conv_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "tools/converter/quantizer/quantize_util.h" - -namespace mindspore::lite { -void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group, - const std::vector &inputs) { - auto attr = std::make_unique(); - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - auto pad_list = GetValue>(prim->GetAttr("pad_list")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; - attr->padRight = pad_list[3]; - - auto dilation = GetValue>(prim->GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; - - auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - attr->kernelH = kernel_size[0]; - attr->kernelW = kernel_size[1]; - - auto stride = GetValue>(prim->GetAttr("stride")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - auto pad_mode = GetValue(prim->GetAttr("pad_mode")); - if (pad_mode == "valid") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "same") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - - int channel_mutiplier = 1; - if (prim->GetAttr("channel_mutiplier") != nullptr) { - channel_mutiplier = GetValue(prim->GetAttr("channel_multiplier")); - } - attr->channelMultiplier = channel_mutiplier; - - MS_ASSERT(inputs.size() == kAnfPopulaterTwo); - auto input_node = inputs[kAnfPopulaterOne]; - MS_ASSERT(input_node != nullptr); - if (input_node->isa()) { - auto param_node = input_node->cast(); - ConvertConvWeight(param_node); - } - - primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - primitive->value.value = attr.release(); -} - -void AnfConvPopulater::PopulaterConv2DSingleGroup(const PrimitivePtr &prim, - const std::unique_ptr &primitive, - const int &group) { - auto attr = std::make_unique(); - attr->group = group; - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - auto pad_list = GetValue>(prim->GetAttr("pad_list")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; - attr->padRight = pad_list[3]; - - auto dilation = GetValue>(prim->GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; - - auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - attr->kernelH = kernel_size[0]; - attr->kernelW = kernel_size[1]; - - auto stride = GetValue>(prim->GetAttr("stride")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - attr->channelOut = GetValue(prim->GetAttr("out_channel")); - - auto pad_mode = GetValue(prim->GetAttr("pad_mode")); - if (pad_mode == "valid") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "same") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - primitive->value.type = schema::PrimitiveType_Conv2D; - primitive->value.value = attr.release(); -} - -void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { - constexpr float qmin = 0; - constexpr float qmax = 255; - *mMin = static_cast((qmin - mean) / stdDev); - *mMax = static_cast((qmax - mean) / stdDev); -} - -void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { - auto narrow_range = prim->GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); - auto num_bits = prim->GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); - - std::vector quants; - schema::QuantParamT quantParam; - auto mean = prim->GetAttr("mean"); - auto std_dev = prim->GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; - } else { - auto inputMin = prim->GetAttr("input_minq"); - auto inputMax = prim->GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->Data()); - float *maxBuf = static_cast(inputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - int biasQuantSize = 0; - auto filterMin = prim->GetAttr("filter_minq"); - auto filterMax = prim->GetAttr("filter_maxq"); - if (filterMin != nullptr && filterMax != nullptr) { - auto filterMinPtr = filterMin->cast(); - auto filterMaxPtr = filterMax->cast(); - float *minBuf = static_cast(filterMinPtr->Data()); - float *maxBuf = static_cast(filterMaxPtr->Data()); - biasQuantSize = filterMinPtr->DataSize(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = *(minBuf++); - quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - } - - quants.clear(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = 0.0; - quantParam.max = 0.0; - quantParam.zeroPoint = 0; - - quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - auto outputMin = prim->GetAttr("output_minq"); - auto outputMax = prim->GetAttr("output_maxq"); - if (outputMin != nullptr && outputMax != nullptr) { - auto outputMinPtr = outputMin->cast(); - auto outputMaxPtr = outputMax->cast(); - float *minBuf = static_cast(outputMinPtr->Data()); - float *maxBuf = static_cast(outputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); - } -} - -int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - MS_ASSERT(primitiveCPtr != nullptr); - auto primitive = std::make_unique(); - - int group = GetValue(prim->GetAttr("group")); - if (group > 1) { - PopulaterConv2DMultiGroup(prim, primitive, group, inputs); - } else { - PopulaterConv2DSingleGroup(prim, primitive, group); - } - primitiveCPtr->SetPrimitiveT(primitive.release()); - - if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveCPtr->SetInputQuantParam(vecInputQuantParam); - primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); - } - return 0; -} -AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h deleted file mode 100644 index f2c3d7b30dc..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_ANF_CONV_PARSER_H -#define MINDSPORE_ANF_CONV_PARSER_H - -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -#include -#include "base/base_ref.h" -#include "abstract/abstract_value.h" -#include "src/param_value_lite.h" -#include "src/ir/tensor.h" - -namespace mindspore::lite { -class AnfConvPopulater : public AnfNodePopulater { - public: - AnfConvPopulater() = default; - ~AnfConvPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; - - private: - template - void ConvertConvWeight(const ParameterPtr ¶m_node) { - MS_ASSERT(param_node != nullptr); - auto param = param_node->default_param(); - auto weight = std::dynamic_pointer_cast(param); - MS_ASSERT(weight != nullptr); - - std::unique_ptr buf(new (std::nothrow) T[weight->tensor_shape_size()]); - if (buf == nullptr) { - MS_LOG(ERROR) << "new buf failed"; - return; - } - - size_t filter_k = weight->tensor_shape()[0]; - size_t filter_c = weight->tensor_shape()[1]; - size_t filter_h = weight->tensor_shape()[2]; - size_t filter_w = weight->tensor_shape()[3]; - T *p1Buff = nullptr; - T *p2Buff = nullptr; - for (size_t k = 0; k < filter_k; ++k) { - for (size_t c = 0; c < filter_c; ++c) { - for (size_t h = 0; h < filter_h; ++h) { - for (size_t w = 0; w < filter_w; ++w) { - p1Buff = reinterpret_cast(weight->tensor_addr()) + - ((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w)); - p2Buff = - buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w)); - *p2Buff = *p1Buff; - } - } - } - } - - auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(), - weight->tensor_shape_size() * sizeof(T)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed: " << ret; - return; - } - - auto abstract_base = param_node->abstract(); - MS_ASSERT(abstract_base != nullptr); - if (utils::isa(abstract_base)) { - auto abstract_tensor = utils::cast(abstract_base); - utils::cast(abstract_tensor->BuildShape())->shape()[0] = filter_c; - utils::cast(abstract_tensor->BuildShape())->shape()[1] = filter_k; - utils::cast(abstract_tensor->BuildShape())->shape()[2] = filter_h; - utils::cast(abstract_tensor->BuildShape())->shape()[3] = filter_w; - } - return; - } - - void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, - const int &group, const std::vector &inputs); - void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, - const int &group); - void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_CONV_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc deleted file mode 100644 index abc78559c41..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "src/ir/tensor.h" -#include "tools/converter/quantizer/quantize_util.h" - -namespace mindspore::lite { -void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { - constexpr float qmin = 0; - constexpr float qmax = 255; - *mMin = static_cast((qmin - mean) / stdDev); - *mMax = static_cast((qmax - mean) / stdDev); -} - -void AnfDepwiseconv2DPopulater::PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { - auto narrow_range = prim->GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); - auto num_bits = prim->GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); - - std::vector quants; - schema::QuantParamT quantParam; - auto mean = prim->GetAttr("mean"); - auto std_dev = prim->GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; - } else { - auto inputMin = prim->GetAttr("input_minq"); - auto inputMax = prim->GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->Data()); - float *maxBuf = static_cast(inputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - int biasQuantSize = 0; - auto filterMin = prim->GetAttr("filter_minq"); - auto filterMax = prim->GetAttr("filter_maxq"); - if (filterMin != nullptr && filterMax != nullptr) { - auto filterMinPtr = filterMin->cast(); - auto filterMaxPtr = filterMax->cast(); - float *minBuf = static_cast(filterMinPtr->Data()); - float *maxBuf = static_cast(filterMaxPtr->Data()); - biasQuantSize = filterMinPtr->DataSize(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = *(minBuf++); - quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - } - - quants.clear(); - for (int i = 0; i < biasQuantSize; ++i) { - quantParam.min = 0.0; - quantParam.max = 0.0; - quantParam.zeroPoint = 0; - - quantParam.scale = - vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - auto outputMin = prim->GetAttr("output_minq"); - auto outputMax = prim->GetAttr("output_maxq"); - if (outputMin != nullptr && outputMax != nullptr) { - auto outputMinPtr = outputMin->cast(); - auto outputMaxPtr = outputMax->cast(); - float *minBuf = static_cast(outputMinPtr->Data()); - float *maxBuf = static_cast(outputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); - } -} - -int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - auto pad_list = GetValue>(prim->GetAttr("pads")); - attr->padUp = pad_list[0]; - attr->padDown = pad_list[1]; - attr->padLeft = pad_list[2]; - attr->padRight = pad_list[3]; - - auto dilation = GetValue>(prim->GetAttr("dilation")); - attr->dilateH = dilation[0]; - attr->dilateW = dilation[1]; - - auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - attr->kernelH = kernel_size[0]; - attr->kernelW = kernel_size[1]; - - auto stride = GetValue>(prim->GetAttr("stride")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - auto pad_mode = GetValue(prim->GetAttr("pad_mode")); - if (pad_mode == "valid") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "same") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - - auto channel_multiplier = GetValue(prim->GetAttr("channel_multiplier")); - attr->channelMultiplier = channel_multiplier; - - MS_ASSERT(inputs.size() == kAnfPopulaterTwo); - auto inputNode = inputs[kAnfPopulaterOne]; - MS_ASSERT(inputNode != nullptr); - if (inputNode->isa()) { - auto paramNode = inputNode->cast(); - auto abstractBase = paramNode->abstract(); - MS_ASSERT(abstractBase != nullptr); - if (utils::isa(abstractBase)) { - auto abstractTensor = utils::cast(abstractBase); - MS_ASSERT(abstractTensor != nullptr); - if (utils::isa(abstractTensor->BuildShape())) { - auto dims = utils::cast(abstractTensor->BuildShape())->shape(); - attr->channelIn = dims[kAnfPopulaterOne]; - } - } - } - - primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - - if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveCPtr->SetInputQuantParam(vecInputQuantParam); - primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); - } - return 0; -} -AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); -AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h deleted file mode 100644 index 110e5ae5502..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H -#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H - -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfDepwiseconv2DPopulater : public AnfNodePopulater { - public: - AnfDepwiseconv2DPopulater() = default; - ~AnfDepwiseconv2DPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; - - private: - void PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, - float *mMax); -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc deleted file mode 100644 index f9d7cfaea65..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_dequant_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc deleted file mode 100644 index ce5645fb350..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_flatten_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_Flatten; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} - -AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h deleted file mode 100644 index 01c0f0a1dd1..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H -#define MINDSPORE_ANF_FLATTEN_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfFlattenPopulater : public AnfNodePopulater { - public: - AnfFlattenPopulater() = default; - ~AnfFlattenPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_FLATTEN_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc deleted file mode 100644 index 16d40c11039..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_make_tuple_populater.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "tools/anf_importer/anf_populater/anf_make_tuple_populater.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfMakeTuplePopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_MakeTuple; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfMakeTuplePopulater("make_tuple", new AnfMakeTuplePopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc deleted file mode 100644 index 261f0cff239..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_matmul_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "src/ir/tensor.h" -#include "tools/converter/quantizer/quantize_util.h" - -namespace mindspore::lite { -void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { - constexpr float qmin = 0; - constexpr float qmax = 255; - *mMin = static_cast((qmin - mean) / stdDev); - *mMax = static_cast((qmax - mean) / stdDev); -} - -void AnfMatmulPopulater::PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { - auto narrow_range = prim->GetAttr("narrow_range"); - bool narrowRangeQuantParam = GetValue(narrow_range); - auto num_bits = prim->GetAttr("num_bits"); - int32_t numbitsRangeQuantParam = GetValue(num_bits); - - std::vector quants; - schema::QuantParamT quantParam; - auto mean = prim->GetAttr("mean"); - auto std_dev = prim->GetAttr("std_dev"); - if (mean != nullptr && std_dev != nullptr) { - auto meanQuantOaram = GetValue(mean); - double stddevQuantOaram = GetValue(std_dev); - float mMin = 0.0; - float mMax = 0.0; - CalQuantParam(meanQuantOaram, stddevQuantOaram, &mMin, &mMax); - quantParam.min = mMin; - quantParam.max = mMax; - } else { - auto inputMin = prim->GetAttr("input_minq"); - auto inputMax = prim->GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->Data()); - float *maxBuf = static_cast(inputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); - - quants.clear(); - auto filterMin = prim->GetAttr("filter_minq"); - auto filterMax = prim->GetAttr("filter_maxq"); - if (filterMin != nullptr && filterMax != nullptr) { - auto filterMinPtr = filterMin->cast(); - auto filterMaxPtr = filterMax->cast(); - float *minBuf = static_cast(filterMinPtr->Data()); - float *maxBuf = static_cast(filterMaxPtr->Data()); - for (int i = 0; i < filterMinPtr->DataSize(); ++i) { - quantParam.min = *(minBuf++); - quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - } - vecInputQuantParam->emplace_back(quants); - } - - quants.clear(); - auto outputMin = prim->GetAttr("output_minq"); - auto outputMax = prim->GetAttr("output_maxq"); - if (outputMin != nullptr && outputMax != nullptr) { - auto outputMinPtr = outputMin->cast(); - auto outputMaxPtr = outputMax->cast(); - float *minBuf = static_cast(outputMinPtr->Data()); - float *maxBuf = static_cast(outputMaxPtr->Data()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, - numbitsRangeQuantParam); - quants.emplace_back(quantParam); - vecOutputQuantParam->emplace_back(quants); - } -} - -int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - attr->transposeA = GetValue(prim->GetAttr("transpose_a")); - attr->transposeB = GetValue(prim->GetAttr("transpose_b")); - - primitive->value.type = schema::PrimitiveType_MatMul; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - if (primitiveCPtr->GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecInputQuantParam; - std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); - primitiveCPtr->SetInputQuantParam(vecInputQuantParam); - primitiveCPtr->SetOutputQuantParam(vecOutputQuantParam); - } - return 0; -} -AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); -AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h deleted file mode 100644 index e67b2e7d157..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_MATMUL_PARSER_H -#define MINDSPORE_ANF_MATMUL_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfMatmulPopulater : public AnfNodePopulater { - public: - AnfMatmulPopulater() = default; - ~AnfMatmulPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; - - private: - void PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, - float *mMax); -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_MATMUL_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc deleted file mode 100644 index bc99452582d..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_mul_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_Mul; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h deleted file mode 100644 index 1dad59e86ba..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H -#define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include - -namespace mindspore::lite { -class AnfMulPopulater : public AnfNodePopulater { - public: - AnfMulPopulater() = default; - ~AnfMulPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc deleted file mode 100644 index 609a73b7ef3..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_node_populater.h" - -namespace mindspore::lite {} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h deleted file mode 100644 index d0290fe6964..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_ANF_NODE_PARSER_H -#define MINDSPORE_ANF_NODE_PARSER_H - -#include -#include "ir/anf.h" -#include "schema/inner/model_generated.h" -#include "src/ops/primitive_c.h" - -namespace mindspore::lite { -constexpr int kAnfPopulaterOne = 1; -constexpr int kAnfPopulaterTwo = 2; -constexpr int kAnfPopulaterThree = 3; -class AnfNodePopulater { - public: - AnfNodePopulater() = default; - virtual ~AnfNodePopulater() = default; - - virtual int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) = 0; -}; - -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_NODE_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc deleted file mode 100644 index 0515bb52f37..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include -namespace mindspore { -namespace lite { -AnfNodePopulaterRegistry::~AnfNodePopulaterRegistry() { - for (auto ite : populaters) { - if (ite.second != nullptr) { - delete ite.second; - ite.second = nullptr; - } - } -} -AnfNodePopulaterRegistry *AnfNodePopulaterRegistry::GetInstance() { - static AnfNodePopulaterRegistry instance; - return &instance; -} -AnfNodePopulater *AnfNodePopulaterRegistry::GetNodePopulater(const std::string &name) { - if (populaters.find(name) == populaters.end()) { - return nullptr; - } - return populaters[name]; -} -void AnfNodePopulaterRegistry::SetNodePopulater(const std::string &name, AnfNodePopulater *populater) { - populaters[name] = populater; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h deleted file mode 100644 index 2f7b984fe29..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H -#define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -#include -namespace mindspore::lite { -class AnfNodePopulaterRegistry { - public: - AnfNodePopulaterRegistry() = default; - virtual ~AnfNodePopulaterRegistry(); - static AnfNodePopulaterRegistry *GetInstance(); - AnfNodePopulater *GetNodePopulater(const std::string &name); - void SetNodePopulater(const std::string &name, AnfNodePopulater *populater); - - private: - std::unordered_map populaters; -}; - -class AnfNodePopulaterRegistrar { - public: - AnfNodePopulaterRegistrar(const std::string &name, AnfNodePopulater *populater) { - AnfNodePopulaterRegistry::GetInstance()->SetNodePopulater(name, populater); - } -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_NODE_PARSER_REGISTRY_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc deleted file mode 100644 index 1e2cc41a345..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_pool_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfPoolPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - if (prim->instance_name() == "MaxPool") { - attr->poolingMode = schema::PoolMode_MAX_POOLING; - } else if (prim->instance_name() == "MeanPool") { - attr->poolingMode = schema::PoolMode_MEAN_POOLING; - } - - auto format = GetValue(prim->GetAttr("data_format")); - if (format == "NCHW") { - attr->format = schema::Format_NCHW; - } else if (format == "NHWC") { - attr->format = schema::Format_NHWC; - } else { - attr->format = schema::Format_NUM_OF_FORMAT; - } - - auto pad_mode = GetValue(prim->GetAttr("padding")); - if (pad_mode == "VALID") { - attr->padMode = schema::PadMode_VALID; - } else if (pad_mode == "SAME") { - attr->padMode = schema::PadMode_SAME; - } else { - attr->padMode = schema::PadMode_NOTSET; - } - - auto kernel_size = GetValue>(prim->GetAttr("ksize")); - attr->windowH = kernel_size[2]; - attr->windowW = kernel_size[3]; - - auto stride = GetValue>(prim->GetAttr("strides")); - attr->strideH = stride[2]; - attr->strideW = stride[3]; - - primitive->value.type = schema::PrimitiveType_Pooling; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfMaxPoolPopulater("MaxPool", new AnfPoolPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h deleted file mode 100644 index d2228dc45d0..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_POOL_PARSER_H -#define MINDSPORE_ANF_POOL_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfPoolPopulater : public AnfNodePopulater { - public: - AnfPoolPopulater() = default; - ~AnfPoolPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_POOL_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc deleted file mode 100644 index 69c0d514fac..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_quant_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfQuantPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfQuantPopulater("Quant", new AnfQuantPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc deleted file mode 100644 index a0818db075f..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_reducemean_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - attr->mode = schema::ReduceMode_ReduceMean; - - attr->keepDims = GetValue(prim->GetAttr("keep_dims")); - if (inputs.size() == kAnfPopulaterTwo) { - auto inputNode = inputs[kAnfPopulaterOne]; - MS_ASSERT(inputNode != nullptr); - if (inputNode->isa()) { - auto valueNode = inputNode->cast(); - MS_ASSERT(valueNode != nullptr); - auto value = valueNode->value(); - MS_ASSERT(value != nullptr); - if (value->isa()) { - auto valTuplPtr = dyn_cast(value); - MS_ASSERT(valTuplPtr != nullptr); - for (size_t i = 0; i < valTuplPtr->size(); i++) { - auto elem = dyn_cast((*valTuplPtr)[i]); - MS_ASSERT(elem != nullptr); - attr->axes.emplace_back(elem->value()); - } - } - } - } - - primitive->value.type = schema::PrimitiveType_Reduce; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfReduceMeanPopulater("ReduceMean", new AnfReduceMeanPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h deleted file mode 100644 index 15e07c08c72..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H -#define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfReduceMeanPopulater : public AnfNodePopulater { - public: - AnfReduceMeanPopulater() = default; - ~AnfReduceMeanPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc deleted file mode 100644 index 1252c068416..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_reshape_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfReshapePopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - MS_ASSERT(inputs.size() == kAnfPopulaterThree - 1); - auto inputNode = inputs[kAnfPopulaterTwo - 1]; - if (inputNode->isa()) { - auto valueNode = inputNode->cast(); - MS_ASSERT(valueNode != nullptr); - auto val = valueNode->value(); - MS_ASSERT(val != nullptr); - if (val->isa()) { - auto tuple = val->cast(); - MS_ASSERT(tuple != nullptr); - for (size_t i = 0; i < tuple->size(); ++i) { - auto elem = tuple->value()[i]->cast(); - MS_ASSERT(elem != nullptr); - attr->shape.emplace_back(static_cast(elem->value())); - } - } - } - - primitive->value.type = schema::PrimitiveType_Reshape; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} - -AnfNodePopulaterRegistrar anfReshapePopulater("Reshape", new AnfReshapePopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h deleted file mode 100644 index a0ce7513a19..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H -#define MINDSPORE_ANF_RESHAPE_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfReshapePopulater : public AnfNodePopulater { - public: - AnfReshapePopulater() = default; - ~AnfReshapePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_RESHAPE_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc deleted file mode 100644 index 428ac6d404e..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_tensoradd_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfTensorAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_Add; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfTensorAddPopulater("TensorAdd", new AnfTensorAddPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h deleted file mode 100644 index e681cef48e6..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H -#define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfTensorAddPopulater : public AnfNodePopulater { - public: - AnfTensorAddPopulater() = default; - ~AnfTensorAddPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_ACTIVATION_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc deleted file mode 100644 index 4e6c1d65c85..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_transpose_populater.h" -#include -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfTransposePopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - MS_ASSERT(inputs.size() == kAnfPopulaterTwo); - auto inputNode = inputs[kAnfPopulaterOne]; - if (inputNode->isa()) { - auto valNode = inputNode->cast(); - MS_ASSERT(valNode != nullptr); - auto val = valNode->value(); - MS_ASSERT(val != nullptr); - if (val->isa()) { - auto tuple = val->cast(); - MS_ASSERT(tuple != nullptr); - for (size_t i = 0; i < tuple->size(); i++) { - auto elem = tuple->value()[i]->cast(); - MS_ASSERT(elem != nullptr); - attr->perm.emplace_back(static_cast(elem->value())); - } - } - } - - primitive->value.type = schema::PrimitiveType_Transpose; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfTransposePopulater("Transpose", new AnfTransposePopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h deleted file mode 100644 index 00699307acc..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H -#define MINDSPORE_ANF_TRANSPOSE_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfTransposePopulater : public AnfNodePopulater { - public: - AnfTransposePopulater() = default; - ~AnfTransposePopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_TRANSPOSE_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc deleted file mode 100644 index d903fe41239..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h" -#include -#include -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" - -namespace mindspore::lite { -int AnfTupleGetItemPopulater::Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) { - auto primitive = std::make_unique(); - auto attr = std::make_unique(); - primitive->value.type = schema::PrimitiveType_TupleGetItem; - primitive->value.value = attr.release(); - MS_ASSERT(primitiveCPtr != nullptr); - primitiveCPtr->SetPrimitiveT(primitive.release()); - return 0; -} -AnfNodePopulaterRegistrar anfTupleGetItemPopulater("tuple_getitem", new AnfTupleGetItemPopulater()); -} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h deleted file mode 100644 index caf670e67e7..00000000000 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_TUPLE_GETITEM_PARSER_H -#define MINDSPORE_TUPLE_GETITEM_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include -namespace mindspore::lite { -class AnfTupleGetItemPopulater : public AnfNodePopulater { - public: - AnfTupleGetItemPopulater() = default; - ~AnfTupleGetItemPopulater() override = default; - int Populate(const PrimitivePtr &prim, PrimitiveC *primitiveCPtr, - const std::vector &inputs) override; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_ANF_BATCHNORM_PARSER_H diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index fe4d64eb93a..6d7c6ed8835 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -19,7 +19,7 @@ #include "schema/inner/model_generated.h" #include "frontend/operator/ops.h" #include "src/param_value_lite.h" -#include "import_from_meta_graphT.h" +#include "tools/anf_importer/import_from_meta_graphT.h" #include "utils/log_adapter.h" #include "include/errorcode.h" @@ -80,7 +80,7 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr &cNode) { MS_ASSERT(nullptr != meta_graph_); MS_ASSERT(nullptr != cNode); - auto primitiveCValue = std::make_shared(cNode->primitive.release()); + auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release()); cNode->primitive = nullptr; // add quant parameter if (cNode->quantType == schema::QuantType_AwareTraining) { @@ -98,7 +98,7 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr(primitiveCValue)); return value_node; } diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index bee80a53476..e033296d21a 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -20,14 +20,12 @@ #include #include -#include #include #include #include -#include #include #include - +#include "src/ops/primitive_c.h" #include "frontend/operator/ops.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "include/errorcode.h" @@ -39,7 +37,6 @@ #include "src/param_value_lite.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" -#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" using string = std::string; using int32 = int32_t; @@ -60,16 +57,16 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ @@ -230,7 +227,8 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &pr auto value = prim->GetAttr(attr_name); break; } - default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } return true; @@ -293,7 +291,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con case FORM_PARSE_TENSOR: { return ObtainCNodeAttrInTensorForm(prim, attr_name, attr_tensor); } - default:MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; + default: + MS_LOG(ERROR) << "parse attr type don't support input of ref_attr_name"; return false; } } @@ -357,7 +356,8 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &val value_ptr = std::make_shared(elems); break; } - default:MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; + default: + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } auto new_value_node = NewValueNode(value_ptr); @@ -395,7 +395,8 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_at case FORM_PARSE_TYPE: { return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); } - default:MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; + default: + MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; return false; } } @@ -472,18 +473,12 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out } inputs.push_back(anfnode_build_map_[input_name]); } - std::string opType = prim->name(); - auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); - if (node_parser == nullptr) { - MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; + auto primitivec_ptr = PrimitiveC::UnPackFromPrimitive(*prim, inputs); + if (primitivec_ptr == nullptr) { + MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name(); return nullptr; } - auto primitiveT = std::make_unique(); - std::shared_ptr primitiveCPtr = std::make_shared(primitiveT.release()); - primitiveCPtr->SetQuantType(quantType); - node_parser->Populate(prim, primitiveCPtr.get(), inputs); - MS_ASSERT(primitiveCPtr != nullptr); - inputs.insert(inputs.begin(), NewValueNode(primitiveCPtr)); + inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); if (node_type == "LayerNorm") { @@ -521,9 +516,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto primitiveT = std::make_unique(); MS_ASSERT(primitiveT != nullptr); primitiveT->value.type = schema::PrimitiveType_MakeTuple; - std::shared_ptr primitiveCPtr = std::make_shared(primitiveT.release()); - MS_ASSERT(primitiveCPtr != nullptr); - inputs.push_back(NewValueNode(primitiveCPtr)); + std::shared_ptr primitivec_ptr = std::make_shared(primitiveT.release()); + MS_ASSERT(primitivec_ptr != nullptr); + inputs.push_back(NewValueNode(primitivec_ptr)); AbstractBasePtrList elem; for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { const onnx::ValueInfoProto &output_node = importProto.output(out_size); @@ -537,9 +532,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output auto primReturn = std::make_unique(); MS_ASSERT(primReturn != nullptr); primReturn->value.type = schema::PrimitiveType_Return; - std::shared_ptr primitiveTReturnValuePtr = std::make_shared(primReturn.release()); - MS_ASSERT(primitiveTReturnValuePtr != nullptr); - inputs.push_back(NewValueNode(primitiveTReturnValuePtr)); + std::shared_ptr primitive_return_value_ptr = std::make_shared(primReturn.release()); + MS_ASSERT(primitive_return_value_ptr != nullptr); + inputs.push_back(NewValueNode(primitive_return_value_ptr)); inputs.push_back(maketuple_ptr); auto return_node = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(return_node); @@ -679,7 +674,7 @@ onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string delete onnx_model; return nullptr; } - (void) close(fd); + (void)close(fd); MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; return onnx_model; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc index b326f44b0a8..73e103d63b9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc @@ -316,34 +316,7 @@ FusionPass::~FusionPass() { } void FusionPass::MergeNodeAttrFromPost(std::unique_ptr &dstOp, std::unique_ptr &postOp, - size_t dstOpOutIdx) { - // // merge quantParam - // if (dstOp->quantParam.empty()) { // not awareing quant - // return; - // } - // MS_ASSERT(postOp->outputIndex.size() == 1); - // if (dstOp->quantParam.size() != dstOp->inputIndex.size() + dstOp->outputIndex.size()) { - // int a = 1; - // } - // MS_ASSERT(dstOp->quantParam.size() == dstOp->inputIndex.size() + dstOp->outputIndex.size()); - // auto &dstQuantParamArray = dstOp->quantParam.at(dstOp->inputIndex.size() + dstOpOutIdx); - // auto &postQuantParamArray = postOp->quantParam.back(); - // if (!(postQuantParamArray != nullptr && postQuantParamArray->param.size() == 1 && - // postQuantParamArray->param.front() != nullptr && postQuantParamArray->param.front()->min != FLT_MAX)) { - // return; // postNode has no quantParam, no need merge - // } - // - // if ((dstQuantParamArray != nullptr && dstQuantParamArray->param.size() != 1) || - // (dstQuantParamArray->param.front() != nullptr && dstQuantParamArray->param.front()->min != FLT_MAX)) { - // return; // dstNode has quantParam, no need merge - // } - // - // dstQuantParamArray->param.front()->min = postQuantParamArray->param.front()->min; - // dstQuantParamArray->param.front()->max = postQuantParamArray->param.front()->max; - // dstQuantParamArray->param.front()->scale = postQuantParamArray->param.front()->scale; - // dstQuantParamArray->param.front()->zeroPoint = postQuantParamArray->param.front()->zeroPoint; - // MS_LOGD("merge quantParam from %s to %s", postOp->name.c_str(), dstOp->name.c_str()); -} + size_t dstOpOutIdx) {} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index dac9b3e5787..824ef52768e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -35,10 +35,12 @@ std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *m auto buf = ReadFile(model_path, &size); if (buf == nullptr) { MS_LOG(ERROR) << "the file buffer is nullptr"; + return nullptr; } flatbuffers::Verifier verify((const uint8_t *)buf, size); if (!tflite::VerifyModelBuffer(verify)) { MS_LOG(ERROR) << "the buffer is invalid and fail to create graph"; + return nullptr; } return tflite::UnPackModel(buf); } diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index e8aec082d5a..7922dbe432c 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -21,9 +21,9 @@ #include #include "tools/common/graph_util.h" #include "tools/common/tensor_util.h" -#include "tools/converter/quantizer/quantize_util.h" #include "schema/inner/ops_generated.h" #include "src/common/utils.h" +#include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 2823338f81f..7a95fc90cc8 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -21,6 +21,7 @@ #include #include #include +#include "tools/converter/quantizer/quantizer.h" #include "src/ops/primitive_c.h" #include "include/errorcode.h" #include "ir/func_graph.h" @@ -29,7 +30,6 @@ #include "base/base.h" #include "ir/primitive.h" #include "abstract/dshape.h" -#include "mindspore/lite/tools/converter/quantizer/quantizer.h" namespace mindspore { namespace lite { @@ -59,13 +59,13 @@ class QuantStrategy { static const std::array mMulTypes; }; -STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, - bool narrowRange, int quant_max, int quant_min, int num_bits); +STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max, + int quant_min, int num_bits); -STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, - bool narrowRange = false, int numBits = UINT8_QUANTIZATION); +STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange = false, + int numBits = UINT8_QUANTIZATION); -template +template T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam->inited); @@ -73,7 +73,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { const auto zeroPoint = quantParam->zeroPoint; const auto numBit = quantParam->numBits; const auto narrowRange = quantParam->narrowRange; - const double maxLimit = static_cast((1 << (unsigned int) numBit) - 1 - zeroPoint) * scale; + const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; double minLimit; if (narrowRange) { minLimit = static_cast(1 - zeroPoint) * scale; @@ -97,7 +97,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { }(); } -template +template T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quant_max, int quant_min) { MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam->inited); diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index dbf93af6fab..963a9635527 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -20,12 +20,12 @@ #include #include #include +#include "schema/inner/model_generated.h" #include "include/errorcode.h" #include "ir/func_graph.h" #include "ir/anf.h" #include "base/base.h" #include "src/param_value_lite.h" -#include "schema/inner/model_generated.h" #include "tools/converter/converter_flags.h" namespace mindspore::lite::quant { diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index babf93d3dc1..5043df1c92c 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -44,7 +44,7 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); auto tensor_shape = tensorT->dims; auto lite_tensor = - new(std::nothrow)Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType); + new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType); if (lite_tensor == nullptr) { MS_LOG(ERROR) << "lite tensor is nullptr"; return input_tensors; @@ -116,8 +116,8 @@ const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *ten auto ret = memcpy_s(tensor_data, size * sizeof(float), tensor->Data(), size * sizeof(float)); if (ret != EOK) { delete tensor_data; - MS_LOG(EXCEPTION) << "memcpy error: " << ret; - return parameter; + MS_LOG(ERROR) << "memcpy error: " << ret; + return nullptr; } param_value->set_tensor_addr(tensor_data); param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t)); @@ -171,31 +171,37 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An auto input_tensors = GetCNodeInputTensors(input_cnode); if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) { FreeInputTensor(&input_tensors); - return any_node; + continue; } MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); auto output_nums = GetOutputTensorNum(input_cnode); std::vector output_tensors{output_nums, new Tensor()}; auto scheam_primitive = PackPrimitiveT(input_cnode); - auto lite_primitive = mindspore::lite::PrimitiveC::CreatePrimitive(scheam_primitive); + auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive); if (lite_primitive == nullptr) { - MS_LOG(DEBUG) << "constant_folding schedule node lite primitive nullptr"; + MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr"; FreeInputTensor(&input_tensors); return nullptr; } lite_primitive->InferShape(input_tensors, output_tensors); auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive); if (lite_kernel == nullptr) { - MS_LOG(DEBUG) << "constant_folding schedule node lite kernel nullptr"; + MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; FreeInputTensor(&input_tensors); return nullptr; } auto ret = lite_kernel->Run(); if (0 != ret) { FreeInputTensor(&input_tensors); - MS_LOG(EXCEPTION) << "run kernel failed, name: " << lite_kernel->name(); + MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name(); + return nullptr; } auto new_parameter = CreateNewParamter(func_graph, output_tensors.front()); + if (new_parameter == nullptr) { + FreeInputTensor(&input_tensors); + MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name(); + return nullptr; + } new_parameter->set_name(input_node->fullname_with_scope()); any_node->set_input(i, new_parameter); FreeInputTensor(&input_tensors); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc index 496b73406c5..99ac179907d 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -16,11 +16,14 @@ #include "tools/optimizer/fusion/conv_activation_fusion.h" #include -#include "mindspore/lite/src/ops/activation.h" #include "src/ops/primitive_c.h" +#include "src/ops/conv2d.h" +#include "src/ops/depthwise_conv2d.h" +#include "src/ops/activation.h" #include "schema/inner/model_generated.h" #include "tools/optimizer/common/gllo_utils.h" + namespace mindspore::opt { namespace { constexpr size_t kActivationInputsLength = 2; @@ -44,8 +47,11 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c CheckIfCNodeIsNull(act_node); CheckInputSize(act_node, kActivationInputsLength); - auto act_primitive = GetValueNode>(act_node->input(0)); - if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) { + auto primitivec = GetValueNode>(act_node->input(0)); + MS_ASSERT(utils::isa>(primitivec)); + auto act_primitivec = utils::cast>(primitivec); + MS_ASSERT(act_primitivec != nullptr); + if (act_primitivec->GetType() != activation_type) { return node; } AnfNodePtr pre_node = act_node->input(1); @@ -59,10 +65,16 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value); if (node_type == schema::PrimitiveType_Conv2D) { - primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetActivationType(activation_type); return pre_node; } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { - primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->activationType = activation_type; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetActivationType(activation_type); return pre_node; } else { MS_LOG(EXCEPTION) << "conv activation pass match only conv2d or depthwise_conv2d "; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index 8f24c0fa785..f746240552f 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -15,6 +15,9 @@ */ #include "tools/optimizer/fusion/conv_biasadd_fusion.h" #include +#include "src/ops/conv2d.h" +#include "src/ops/depthwise_conv2d.h" +#include "src/ops/deconv2d.h" #include "src/ops/primitive_c.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" @@ -57,12 +60,20 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { MS_ASSERT(primitive != nullptr); auto type = (schema::PrimitiveType)primitive->Type(); if (type == schema::PrimitiveType_Conv2D) { - return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + return primc->GetChannelOut(); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * - primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + return primc->GetChannelMultiplier() * primc->GetChannelIn(); } else if (type == schema::PrimitiveType_DeConv2D) { - return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut; + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + return primc->GetChannelOut(); } else { MS_LOG(ERROR) << "Unsupported opType, " << type; return 0; @@ -83,7 +94,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c if (kernel_nums <= 0) { MS_LOG(EXCEPTION) << "kernel num less than 0"; } - auto add_bias_data = new(std::nothrow) float[kernel_nums]; + auto add_bias_data = new (std::nothrow) float[kernel_nums]; if (add_bias_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; return; @@ -151,13 +162,22 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons GenConvNewBias(func_graph, conv_node, add_node); auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value != nullptr); - auto type = primitiveT_value->GetPrimitiveT()->value.type; + auto type = primitiveT_value->Type(); if (type == schema::PrimitiveType_Conv2D) { - primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DeConv2D) { - primitiveT_value->GetPrimitiveT()->value.AsDeConv2D()->hasBias = true; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetHasBias(true); } else { MS_LOG(EXCEPTION) << "Unsupported opType, " << type; } diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index db6ee61ad9b..9256a60a635 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -17,6 +17,8 @@ #include "tools/optimizer/fusion/conv_transform_fusion.h" #include #include "src/ops/primitive_c.h" +#include "src/ops/conv2d.h" +#include "src/ops/depthwise_conv2d.h" #include "src/param_value_lite.h" #include "schema/inner/model_generated.h" #include "tools/optimizer/common/gllo_utils.h" @@ -38,12 +40,18 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { MS_ASSERT(value != nullptr); auto primitive = value->cast(); MS_ASSERT(primitive != nullptr); - auto type = primitive->GetPrimitiveT()->value.type; + auto type = (schema::PrimitiveType)primitive->Type(); + if (type == schema::PrimitiveType_Conv2D) { - return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + return primc->GetChannelOut(); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * - primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + MS_ASSERT(utils::isa>(primitive)); + auto primc = utils::cast>(primitive); + MS_ASSERT(primc != nullptr); + return primc->GetChannelMultiplier() * primc->GetChannelIn(); } else { MS_LOG(ERROR) << "Unsupported opType, " << type; return 0; @@ -74,12 +82,12 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co MS_LOG(INFO) << "Unsupported conv node, " << conv_node->DebugString(); return node; } - auto trans_scale = new(std::nothrow) float[kernel_nums]; + auto trans_scale = new (std::nothrow) float[kernel_nums]; if (trans_scale == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; return nullptr; } - auto trans_bias = new(std::nothrow) float[kernel_nums]; + auto trans_bias = new (std::nothrow) float[kernel_nums]; if (trans_bias == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; delete trans_scale; @@ -91,11 +99,17 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co delete[] trans_scale; auto primitiveT_value = GetValueNode>(conv_node->input(0)); MS_ASSERT(primitiveT_value != nullptr); - auto type = primitiveT_value->GetPrimitiveT()->value.type; + auto type = primitiveT_value->Type(); if (type == schema::PrimitiveType_Conv2D) { - primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + primc->SetHasBias(true); } else { MS_LOG(EXCEPTION) << "Unsupported opType, " << type; } @@ -170,7 +184,7 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, bias_data = reinterpret_cast(bias_tensor->tensor_addr()); bias_flag = true; } else { - bias_data = new(std::nothrow) float[kernel_num]; + bias_data = new (std::nothrow) float[kernel_num]; if (bias_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; return; @@ -186,7 +200,7 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, const float *trans_scale) const { MS_ASSERT(weight_data != nullptr); - auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size]; + auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size]; MS_ASSERT(new_weight_data != nullptr); auto data_size = kernel_num * kernel_size * sizeof(float); if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { @@ -212,7 +226,7 @@ const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_nu const float *trans_scale, const float *trans_bias) const { MS_ASSERT(bias_data != nullptr); if (bias_flag) { - auto tmp_bias_data = new(std::nothrow) float[kernel_num]; + auto tmp_bias_data = new (std::nothrow) float[kernel_num]; if (tmp_bias_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; return;